diff --git a/slider/beatmap.py b/slider/beatmap.py index a0b95e9..5a76beb 100644 --- a/slider/beatmap.py +++ b/slider/beatmap.py @@ -1680,6 +1680,60 @@ def _resolve_stacking_old(self, hit_objects, ar, cs): return hit_objects + @lazyval + def _hit_object_times(self): + """a (sorted) list of hitobject time's, so they can be searched with + ``np.searchsorted`` + """ + return [hitobj.time for hitobj in self._hit_objects] + + def closest_hitobject(self, t, side="left"): + """The hitobject closest in time to ``t``. + + Parameters + ---------- + t : datetime.timedelta + The time to find the hitobject closest to. + side : {"left", "right"} + Whether to prefer the earlier (left) or later (right) hitobject + when breaking ties. + + Returns + ------- + hit_object : HitObject + The closest hitobject in time to ``t``. + None + If the beatmap has no hitobjects. + """ + if len(self._hit_objects) == 0: + raise ValueError(f"The beatmap {self!r} must have at least one " + "hit object to determine the closest hitobject.") + if len(self._hit_objects) == 1: + return self._hit_objects[0] + + i = np.searchsorted(self._hit_object_times, t) + # if ``t`` is after the last hitobject, an index of + # len(self._hit_objects) will be returned. The last hitobject will + # always be the closest hitobject in this case. + if i == len(self._hit_objects): + return self._hit_objects[-1] + # similar logic follows for the first hitobject. + if i == 0: + return self._hit_objects[0] + + # searchsorted tells us the two closest hitobjects, but not which is + # closer. Check both candidates. + hitobj1 = self._hit_objects[i - 1] + hitobj2 = self._hit_objects[i] + dist1 = abs(hitobj1.time - t) + dist2 = abs(hitobj2.time - t) + + hitobj1_closer = dist1 <= dist2 if side == "left" else dist1 < dist1 + + if hitobj1_closer: + return hitobj1 + return hitobj2 + @lazyval def max_combo(self): """The highest combo that can be achieved on this beatmap. diff --git a/slider/tests/test_beatmap.py b/slider/tests/test_beatmap.py index d8f4adb..ec7c344 100644 --- a/slider/tests/test_beatmap.py +++ b/slider/tests/test_beatmap.py @@ -186,6 +186,26 @@ def test_hit_objects_hard_rock(beatmap): Position(x=301, y=209)] +def test_closest_hitobject(): + beatmap = slider.example_data.beatmaps.miiro_vs_ai_no_scenario('Beginner') + hit_object1 = beatmap.hit_objects()[4] + hit_object2 = beatmap.hit_objects()[5] + hit_object3 = beatmap.hit_objects()[6] + + middle_t = timedelta(milliseconds=11076 - ((11076 - 9692) / 2)) + + assert hit_object1.time == timedelta(milliseconds=8615) + assert hit_object2.time == timedelta(milliseconds=9692) + assert hit_object3.time == timedelta(milliseconds=11076) + + assert beatmap.closest_hitobject(timedelta(milliseconds=8615)) == \ + hit_object1 + assert beatmap.closest_hitobject(timedelta(milliseconds=(8615 - 30))) == \ + hit_object1 + assert beatmap.closest_hitobject(middle_t) == hit_object2 + assert beatmap.closest_hitobject(middle_t, side="right") == hit_object3 + + def test_ar(beatmap): assert beatmap.ar() == 9.5