Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
koenvo committed Jun 10, 2024
1 parent 670edaa commit 473037a
Show file tree
Hide file tree
Showing 20 changed files with 379 additions and 90 deletions.
70 changes: 67 additions & 3 deletions kloppy/domain/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
overload,
Iterable,
NamedTuple,
Tuple,
)


if sys.version_info >= (3, 8):
from typing import Literal
else:
Expand Down Expand Up @@ -128,6 +128,10 @@ class Position:
def __str__(self):
return self.name

@classmethod
def unknown(cls) -> "Position":
    """Build the placeholder position used when no real position applies.

    Returns:
        An instance created with an empty ``position_id`` and the name
        ``"Unknown"``.
    """
    placeholder = cls(position_id="", name="Unknown")
    return placeholder


@dataclass(frozen=True)
class Player:
Expand All @@ -152,8 +156,9 @@ class Player:
last_name: str = None

# match specific
starting: bool = None
position: Position = None
positions: TimeContainer[Position] = field(
default_factory=TimeContainer, compare=False
)

attributes: Optional[Dict] = field(default_factory=dict, compare=False)

Expand All @@ -165,6 +170,25 @@ def full_name(self):
return f"{self.first_name} {self.last_name}"
return f"{self.team.ground}_{self.jersey_no}"

@property
def position(self) -> Optional[Position]:
    """Latest value recorded in ``self.positions``.

    Returns:
        Whatever ``self.positions.last()`` yields, or ``None`` when that
        call raises ``KeyError``.
    """
    try:
        latest = self.positions.last()
    except KeyError:
        latest = None
    return latest

@property
def starting(self) -> bool:
    """Tell whether the player has a position at the beginning of the match.

    Returns:
        ``True`` when ``starting_position`` is anything other than ``None``.
    """
    opening_position = self.starting_position
    return opening_position is not None

@property
def starting_position(self):
    """Value of ``self.positions`` at the start of the match.

    Returns:
        The result of ``self.positions.at_start()``, or ``None`` when that
        call raises ``KeyError``.
    """
    try:
        opening = self.positions.at_start()
    except KeyError:
        opening = None
    return opening

def __str__(self):
return self.full_name

Expand All @@ -176,6 +200,46 @@ def __eq__(self, other):
return False
return self.player_id == other.player_id

@classmethod
def build(
    cls,
    player_id: str,
    team: "Team",
    jersey_no: Optional[int],
    name: str = None,
    first_name: str = None,
    last_name: str = None,
    starting_position: Optional[Position] = None,
    periods: Optional[List[Period]] = None,
    attributes: Optional[dict] = None,
):
    """Create a player, seeding the position timeline when a starting position is given.

    Args:
        player_id: Identifier forwarded to the constructor.
        team: Team forwarded to the constructor.
        jersey_no: Jersey number forwarded to the constructor.
        name: Forwarded to the constructor unchanged.
        first_name: Forwarded to the constructor unchanged.
        last_name: Forwarded to the constructor unchanged.
        starting_position: When truthy, stored in the timeline at
            ``periods[0].start_time``.
        periods: Required (non-empty) whenever ``starting_position`` is used.
        attributes: Extra attributes; a fresh ``dict`` is used when ``None``.

    Raises:
        KloppyError: If ``starting_position`` is given without ``periods``.
    """
    timeline = TimeContainer()
    if starting_position:
        if not periods:
            raise KloppyError(
                "You must pass periods when using starting_position"
            )
        # The first period's start is used as the time of the initial position.
        first_period = periods[0]
        timeline.set(first_period.start_time, starting_position)

    return cls(
        player_id=player_id,
        team=team,
        jersey_no=jersey_no,
        name=name,
        first_name=first_name,
        last_name=last_name,
        positions=timeline,
        attributes={} if attributes is None else attributes,
    )

def set_position(self, time: Time, position: Optional[Position]):
    """Store ``position`` in ``self.positions`` under the key ``time``.

    Args:
        time: Key handed to the positions container.
        position: Value to store; ``None`` is accepted by the signature.
    """
    timeline = self.positions
    timeline.set(time, position)


@dataclass
class Team:
Expand Down
17 changes: 16 additions & 1 deletion kloppy/domain/models/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
Callable,
Optional,
TYPE_CHECKING,
NamedTuple,
)

from kloppy.domain.models.common import (
DatasetType,
AttackingDirection,
OrientationError,
Position,
)
from kloppy.utils import (
camelcase_to_snakecase,
Expand All @@ -29,7 +31,7 @@
from .formation import FormationType
from .pitch import Point

from ...exceptions import OrphanedRecordError, InvalidFilterError
from ...exceptions import OrphanedRecordError, InvalidFilterError, KloppyError

if TYPE_CHECKING:
from .tracking import Frame
Expand Down Expand Up @@ -879,6 +881,7 @@ class SubstitutionEvent(Event):
"""

replacement_player: Player
position: Optional[Position] = None

event_type: EventType = EventType.SUBSTITUTION
event_name: str = "substitution"
Expand Down Expand Up @@ -1113,6 +1116,18 @@ def generic_record_converter(event: Event):
map(generic_record_converter, self.records)
)

def aggregate(self, type_: str) -> List[Any]:
    """Run the aggregator selected by ``type_`` over this dataset.

    Args:
        type_: Name of the aggregator. Only ``"minutes_played"`` is
            recognised here.

    Returns:
        Whatever the selected aggregator's ``aggregate`` returns for this
        dataset.

    Raises:
        KloppyError: If ``type_`` does not name a known aggregator.
    """
    if type_ != "minutes_played":
        # The previous message ("No aggregator ... not found") was a
        # garbled double negative.
        raise KloppyError(f"Aggregator '{type_}' not found")

    # Kept as a function-level import, as in the original
    # (presumably to avoid a circular import - TODO confirm).
    from kloppy.domain.services.aggregators.minutes_played import (
        MinutesPlayedAggregator,
    )

    return MinutesPlayedAggregator().aggregate(self)


__all__ = [
"EnumQualifier",
Expand Down
82 changes: 55 additions & 27 deletions kloppy/domain/models/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
List,
Tuple,
NamedTuple,
Literal,
)

from sortedcontainers import SortedList
from sortedcontainers import SortedDict

from kloppy.exceptions import KloppyError

Expand Down Expand Up @@ -51,11 +52,13 @@ def contains(self, timestamp: datetime):

@property
def start_time(self) -> "Time":
return Time(period=self, timestamp=self.start_timestamp)
return Time(period=self, timestamp=timedelta(0))

@property
def end_time(self) -> "Time":
return Time(period=self, timestamp=self.end_timestamp)
return Time(
period=self, timestamp=self.end_timestamp - self.start_timestamp
)

@property
def duration(self) -> timedelta:
Expand Down Expand Up @@ -94,6 +97,17 @@ class Time:
period: "Period"
timestamp: timedelta

@classmethod
def from_period(
    cls,
    period: Period,
    type_: Union[Literal["start"], Literal["end"]] = "start",
):
    """Build the time at the start or the end of ``period``.

    Args:
        period: Period the resulting time belongs to.
        type_: ``"start"`` gives a zero offset; ``"end"`` gives an offset of
            ``period.duration``.

    Returns:
        A new instance with ``period`` and the chosen offset as ``timestamp``.

    Raises:
        ValueError: If ``type_`` is neither ``"start"`` nor ``"end"``.
            Previously any other value silently behaved like ``"end"``.
    """
    if type_ == "start":
        offset = timedelta(0)
    elif type_ == "end":
        offset = period.duration
    else:
        raise ValueError(f"type_ must be 'start' or 'end', got {type_!r}")

    return cls(period=period, timestamp=offset)

@overload
def __sub__(self, other: timedelta) -> "Time":
...
Expand Down Expand Up @@ -178,46 +192,60 @@ def __str__(self):
m, s = divmod(self.timestamp.total_seconds(), 60)
return f"P{self.period.id}T{m:02.0f}:{s:02.0f}"


T = TypeVar("T")
def __hash__(self):
    """Hash on the period's id together with the timestamp in seconds."""
    period_id = self.period.id
    seconds = self.timestamp.total_seconds()
    return hash((period_id, seconds))


class Pair(NamedTuple):
key: Time
item: T
T = TypeVar("T")


class TimeContainer(Generic[T]):
def __init__(self):
self.items: SortedList = SortedList(key=lambda pair: pair.key)
self.items: SortedDict = SortedDict()

def add(self, time: Time, item: T):
self.items.add(Pair(key=time, item=item))
def set(self, time: Time, item: Optional[T]):
self.items[time] = item # Pair(key=time, item=item)

def value_at(self, time: Time) -> T:
idx = self.items.bisect_left(Pair(key=time, item=None)) - 1
def value_at(self, time: Time) -> Optional[T]:
idx = self.items.bisect_right(time) - 1
if idx < 0:
raise ValueError("Not found")
return self.items[idx].item
raise KeyError("Not found")
return self.items.values()[idx]

def __getitem__(self, item: Time):
    """Support ``container[time]`` by delegating to ``value_at``."""
    found = self.value_at(item)
    return found

def __setitem__(self, key: Time, value: Optional[T]):
    """Support ``container[time] = value`` by delegating to ``set``."""
    self.set(time=key, item=value)

def ranges(self, add_end: bool = True) -> List[Tuple[Time, Time, T]]:
def ranges(self) -> List[Tuple[Time, Time, T]]:
items = list(self.items)
if not items:
return []

if add_end:
items.append(
Pair(
# Ugly way to get us to the end of the last period
key=items[0].key + timedelta(seconds=10_000_000),
item=None,
)
)

if len(items) < 2:
raise ValueError("Cannot create ranges when length < 2")

ranges_ = []
for start_pair, end_pair in zip(items[:-1], items[1:]):
ranges_.append((start_pair.key, end_pair.key, start_pair.item))
for start_time, end_time in zip(items[:-1], items[1:]):
ranges_.append((start_time, end_time, self.items[start_time]))
return ranges_

def last(self):
    """Return the value stored under the final key of ``self.items``.

    Raises:
        KeyError: If the container holds no entries.
    """
    if len(self.items) == 0:
        raise KeyError

    latest_key = self.items.keys()[-1]
    return self.items[latest_key]

def at_start(self):
    """Return the value at the beginning of the match.

    Follows ``prev_period`` links back from the period of the first stored
    key until there is no previous period, then looks up the value at the
    start of that period via ``value_at``.

    Raises:
        KeyError: If the container holds no entries. Anything raised by
            ``value_at`` propagates unchanged.
    """
    if not self.items:
        raise KeyError

    earliest_key: Time = self.items.keys()[0]

    # Walk back to the first period of the match.
    period = earliest_key.period
    while period.prev_period:
        period = period.prev_period

    match_start = Time.from_period(period, "start")
    return self.value_at(match_start)
10 changes: 10 additions & 0 deletions kloppy/domain/services/aggregators/aggregator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from abc import ABC, abstractmethod
from typing import Dict, Any, Hashable, List, NamedTuple

from kloppy.domain import EventDataset


class EventDatasetAggregator(ABC):
    """Abstract interface for reducing an ``EventDataset`` to a list of records."""

    @abstractmethod
    def aggregate(self, dataset: EventDataset) -> List[NamedTuple]:
        """Produce the aggregated records for ``dataset``.

        Subclasses must override this; the base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError
72 changes: 72 additions & 0 deletions kloppy/domain/services/aggregators/minutes_played.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from datetime import timedelta
from typing import Dict, List, NamedTuple, Union, Optional

from kloppy.domain import EventDataset, Player, Position, Time
from kloppy.domain.services.aggregators.aggregator import (
EventDatasetAggregator,
)


class MinutesPlayed(NamedTuple):
    """Playing-time record for one player."""

    # Player the record belongs to.
    player: Player
    # Start of the covered span.
    start_time: Time
    # End of the covered span.
    end_time: Time
    # Length of the span as a timedelta.
    duration: timedelta


class MinutesPlayedPerPosition(NamedTuple):
    """Playing-time record for one player in one position range."""

    # Player the record belongs to.
    player: Player
    # Position associated with this range.
    position: Position
    # Start of the range.
    start_time: Time
    # End of the range.
    end_time: Time
    # Length of the range as a timedelta.
    duration: timedelta


class MinutesPlayedAggregator(EventDatasetAggregator):
    """Derive playing time per player from each player's position ranges.

    Args:
        aggregate_position: When ``True`` (the default) all position ranges
            of a player are collapsed into a single ``MinutesPlayed`` record.
            When ``False`` one ``MinutesPlayedPerPosition`` record is emitted
            per range.
    """

    def __init__(self, aggregate_position: bool = True):
        self.aggregate_position = aggregate_position

    def aggregate(
        self, dataset: EventDataset
    ) -> List[Union[MinutesPlayedPerPosition, MinutesPlayed]]:
        """Build the playing-time records for every player in ``dataset``.

        Players are visited team by team, in the order given by
        ``dataset.metadata.teams`` and ``team.players``. A player whose
        ``positions.ranges()`` is empty contributes no records.

        Returns:
            A list of ``MinutesPlayed`` records when ``aggregate_position``
            is true, otherwise a list of ``MinutesPlayedPerPosition`` records.
        """
        items = []

        for team in dataset.metadata.teams:
            for player in team.players:
                ranges = player.positions.ranges()

                if not self.aggregate_position:
                    items.extend(
                        MinutesPlayedPerPosition(
                            player=player,
                            position=position,
                            start_time=start_time,
                            end_time=end_time,
                            duration=end_time - start_time,
                        )
                        for start_time, end_time, position in ranges
                    )
                    continue

                if not ranges:
                    continue

                # Span from the start of the first range to the end of the
                # last one. The record previously stored the start time in
                # ``end_time`` as well.
                first_start = ranges[0][0]
                last_end = ranges[-1][1]
                items.append(
                    MinutesPlayed(
                        player=player,
                        start_time=first_start,
                        end_time=last_end,
                        duration=last_end - first_start,
                    )
                )

        return items
Loading

0 comments on commit 473037a

Please sign in to comment.