Skip to content

Commit

Permalink
Merge pull request #294 from probberechts/fix/timestamps
Browse files Browse the repository at this point in the history
Uniform implementation of timestamps
  • Loading branch information
koenvo committed Apr 3, 2024
2 parents cd3d116 + dc90078 commit e678d62
Show file tree
Hide file tree
Showing 36 changed files with 1,667 additions and 258 deletions.
3 changes: 2 additions & 1 deletion kloppy/domain/models/code.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import timedelta
from dataclasses import dataclass, field
from typing import List, Dict, Callable, Union, Any

Expand Down Expand Up @@ -26,7 +27,7 @@ class Code(DataRecord):

code_id: str
code: str
end_timestamp: float
end_timestamp: timedelta
labels: Dict[str, Union[bool, str]] = field(default_factory=dict)

@property
Expand Down
35 changes: 24 additions & 11 deletions kloppy/domain/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field, replace
from datetime import datetime, timedelta
from enum import Enum, Flag
from typing import (
Dict,
Expand Down Expand Up @@ -34,6 +35,7 @@
OrientationError,
InvalidFilterError,
KloppyParameterError,
KloppyError,
)


Expand Down Expand Up @@ -238,21 +240,32 @@ class Period:
Period
Attributes:
id: `1` for first half, `2` for second half, `3` for first overtime,
`4` for second overtime, and `5` for penalty shootouts
start_timestamp: timestamp given by provider (can be unix timestamp or relative)
end_timestamp: timestamp given by provider (can be unix timestamp or relative)
id: `1` for first half, `2` for second half, `3` for first half of
overtime, `4` for second half of overtime, `5` for penalty shootout
start_timestamp: The UTC datetime of the kick-off or, if the
absolute datetime is not available, the offset between the start
of the data feed and the period's kick-off
end_timestamp: The UTC datetime of the final whistle or, if the
absolute datetime is not available, the offset between the start
of the data feed and the period's final whistle
attacking_direction: See [`AttackingDirection`][kloppy.domain.models.common.AttackingDirection]
"""

id: int
start_timestamp: float
end_timestamp: float
start_timestamp: Union[datetime, timedelta]
end_timestamp: Union[datetime, timedelta]

def contains(self, timestamp: float):
return self.start_timestamp <= timestamp <= self.end_timestamp
def contains(self, timestamp: datetime):
if isinstance(self.start_timestamp, datetime) and isinstance(
self.end_timestamp, datetime
):
return self.start_timestamp <= timestamp <= self.end_timestamp
raise KloppyError(
"This method can only be used when start_timestamp and end_timestamp are a datetime"
)

@property
def duration(self):
def duration(self) -> timedelta:
return self.end_timestamp - self.start_timestamp

def __eq__(self, other):
Expand Down Expand Up @@ -811,7 +824,7 @@ class DataRecord(ABC):
Attributes:
period: See [`Period`][kloppy.domain.models.common.Period]
timestamp: Timestamp of occurrence
timestamp: Timestamp of occurrence, relative to the period kick-off
ball_owning_team: See [`Team`][kloppy.domain.models.common.Team]
ball_state: See [`Team`][kloppy.domain.models.common.BallState]
"""
Expand All @@ -820,7 +833,7 @@ class DataRecord(ABC):
prev_record: Optional["DataRecord"] = field(init=False)
next_record: Optional["DataRecord"] = field(init=False)
period: Period
timestamp: float
timestamp: timedelta
ball_owning_team: Optional[Team]
ball_state: Optional[BallState]

Expand Down
2 changes: 1 addition & 1 deletion kloppy/domain/models/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ def matches(self, filter_) -> bool:
return True

def __str__(self):
m, s = divmod(self.timestamp, 60)
m, s = divmod(self.timestamp.total_seconds(), 60)

event_type = (
self.__class__.__name__
Expand Down
23 changes: 17 additions & 6 deletions kloppy/infra/serializers/code/sportscode.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from datetime import timedelta
from typing import Union, IO, NamedTuple

from lxml import objectify, etree
Expand Down Expand Up @@ -50,15 +51,19 @@ def deserialize(self, inputs: SportsCodeInputs) -> CodeDataset:
all_instances = objectify.fromstring(inputs.data.read())

codes = []
period = Period(id=1, start_timestamp=0, end_timestamp=0)
period = Period(
id=1,
start_timestamp=timedelta(seconds=0),
end_timestamp=timedelta(seconds=0),
)
for instance in all_instances.ALL_INSTANCES.iterchildren():
end_timestamp = float(instance.end)
end_timestamp = timedelta(seconds=float(instance.end))

code = Code(
period=period,
code_id=str(instance.ID),
code=str(instance.code),
timestamp=float(instance.start),
timestamp=timedelta(seconds=float(instance.start)),
end_timestamp=end_timestamp,
labels=parse_labels(instance),
ball_state=None,
Expand Down Expand Up @@ -88,7 +93,7 @@ def serialize(self, dataset: CodeDataset) -> bytes:
root = etree.Element("file")
all_instances = etree.SubElement(root, "ALL_INSTANCES")
for i, code in enumerate(dataset.codes):
relative_period_start = 0
relative_period_start = timedelta(seconds=0)
for period in dataset.metadata.periods:
if period == code.period:
break
Expand All @@ -100,10 +105,16 @@ def serialize(self, dataset: CodeDataset) -> bytes:
id_.text = code.code_id or str(i + 1)

start = etree.SubElement(instance, "start")
start.text = str(relative_period_start + code.start_timestamp)
start.text = str(
relative_period_start.total_seconds()
+ code.start_timestamp.total_seconds()
)

end = etree.SubElement(instance, "end")
end.text = str(relative_period_start + code.end_timestamp)
end.text = str(
relative_period_start.total_seconds()
+ code.end_timestamp.total_seconds()
)

code_ = etree.SubElement(instance, "code")
code_.text = code.code
Expand Down
57 changes: 46 additions & 11 deletions kloppy/infra/serializers/event/datafactory/deserializer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import logging
from datetime import timedelta, datetime, timezone
from dataclasses import replace
from typing import Dict, List, Tuple, Union, IO, NamedTuple

from kloppy.domain import (
Expand Down Expand Up @@ -155,8 +157,10 @@
DF_EVENT_TYPE_PENALTY_SHOOTOUT_POST = 183


def parse_str_ts(raw_event: Dict) -> float:
return raw_event["t"]["m"] * 60 + (raw_event["t"]["s"] or 0)
def parse_str_ts(raw_event: Dict) -> timedelta:
return timedelta(
seconds=raw_event["t"]["m"] * 60 + (raw_event["t"]["s"] or 0)
)


def _parse_coordinates(coordinates: Dict[str, float]) -> Point:
Expand Down Expand Up @@ -397,8 +401,21 @@ def deserialize(self, inputs: DatafactoryInputs) -> EventDataset:
# setup periods
status = incidences.pop(DF_EVENT_CLASS_STATUS)
# start timestamps are fixed
start_ts = {1: 0, 2: 45 * 60, 3: 90 * 60, 4: 105 * 60, 5: 120 * 60}
start_ts = {
1: timedelta(minutes=0),
2: timedelta(minutes=45),
3: timedelta(minutes=90),
4: timedelta(minutes=105),
5: timedelta(minutes=120),
}
# check for end status updates to setup periods
start_event_types = {
DF_EVENT_TYPE_STATUS_MATCH_START,
DF_EVENT_TYPE_STATUS_SECOND_HALF_START,
DF_EVENT_TYPE_STATUS_FIRST_EXTRA_START,
DF_EVENT_TYPE_STATUS_SECOND_EXTRA_START,
DF_EVENT_TYPE_STATUS_PENALTY_SHOOTOUT_START,
}
end_event_types = {
DF_EVENT_TYPE_STATUS_MATCH_END,
DF_EVENT_TYPE_STATUS_FIRST_HALF_END,
Expand All @@ -408,15 +425,33 @@ def deserialize(self, inputs: DatafactoryInputs) -> EventDataset:
}
periods = {}
for status_update in status.values():
if status_update["type"] not in end_event_types:
if status_update["type"] not in (
start_event_types | end_event_types
):
continue
timestamp = datetime.strptime(
match["date"]
+ status_update["time"]
+ match["stadiumGMT"],
"%Y%m%d%H:%M:%S%z",
).astimezone(timezone.utc)
half = status_update["t"]["half"]
end_ts = parse_str_ts(status_update)
periods[half] = Period(
id=half,
start_timestamp=start_ts[half],
end_timestamp=end_ts,
)
if status_update["type"] == DF_EVENT_TYPE_STATUS_MATCH_START:
half = 1
if status_update["type"] in start_event_types:
periods[half] = Period(
id=half,
start_timestamp=timestamp,
end_timestamp=None,
)
elif status_update["type"] in end_event_types:
if half not in periods:
raise DeserializationError(
f"Missing start event for period {half}"
)
periods[half] = replace(
periods[half], end_timestamp=timestamp
)

# exclude goals, already listed as shots too
incidences.pop(DF_EVENT_CLASS_GOALS)
Expand Down Expand Up @@ -444,7 +479,7 @@ def deserialize(self, inputs: DatafactoryInputs) -> EventDataset:
# skip invalid event
continue

timestamp = parse_str_ts(raw_event)
timestamp = parse_str_ts(raw_event) - start_ts[period.id]
if (
previous_event is not None
and previous_event["t"]["half"] != raw_event["t"]["half"]
Expand Down
32 changes: 23 additions & 9 deletions kloppy/infra/serializers/event/metrica/json_deserializer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import json
from dataclasses import replace
from datetime import timedelta
from typing import Dict, List, NamedTuple, IO, Optional

from kloppy.domain import (
Expand All @@ -10,6 +11,7 @@
CarryResult,
EventDataset,
PassResult,
Period,
Point,
Provider,
Qualifier,
Expand Down Expand Up @@ -106,15 +108,21 @@ def _parse_subtypes(event: dict) -> List:


def _parse_pass(
event: Dict, previous_event: Dict, subtypes: List, team: Team
period: Period,
event: Dict,
previous_event: Dict,
subtypes: List,
team: Team,
) -> Dict:
event_type_id = event["type"]["id"]

if event_type_id == MS_PASS_OUTCOME_COMPLETE:
result = PassResult.COMPLETE
receiver_player = team.get_player_by_id(event["to"]["id"])
receiver_coordinates = _parse_coordinates(event["end"])
receive_timestamp = event["end"]["time"]
receive_timestamp = (
timedelta(seconds=event["end"]["time"]) - period.start_timestamp
)
else:
if event_type_id == MS_PASS_OUTCOME_OUT:
result = PassResult.OUT
Expand Down Expand Up @@ -208,11 +216,12 @@ def _parse_shot(event: Dict, previous_event: Dict, subtypes: List) -> Dict:
return dict(result=result, qualifiers=qualifiers)


def _parse_carry(event: Dict) -> Dict:
def _parse_carry(period: Period, event: Dict) -> Dict:
return dict(
result=CarryResult.COMPLETE,
end_coordinates=_parse_coordinates(event["end"]),
end_timestamp=event["end"]["time"],
end_timestamp=timedelta(seconds=event["end"]["time"])
- period.start_timestamp,
)


Expand Down Expand Up @@ -285,7 +294,8 @@ def deserialize(self, inputs: MetricaJsonEventDataInputs) -> EventDataset:
generic_event_kwargs = dict(
# from DataRecord
period=period,
timestamp=raw_event["start"]["time"],
timestamp=timedelta(seconds=raw_event["start"]["time"])
- period.start_timestamp,
ball_owning_team=_parse_ball_owning_team(event_type, team),
ball_state=BallState.ALIVE,
# from Event
Expand All @@ -301,6 +311,7 @@ def deserialize(self, inputs: MetricaJsonEventDataInputs) -> EventDataset:
continue
elif event_type in MS_PASS_TYPES:
pass_event_kwargs = _parse_pass(
period=period,
event=raw_event,
previous_event=previous_event,
subtypes=subtypes,
Expand Down Expand Up @@ -332,7 +343,9 @@ def deserialize(self, inputs: MetricaJsonEventDataInputs) -> EventDataset:
)

elif event_type == MS_EVENT_TYPE_CARRY:
carry_event_kwargs = _parse_carry(event=raw_event)
carry_event_kwargs = _parse_carry(
period=period, event=raw_event
)
event = self.event_factory.build_carry(
qualifiers=None,
**carry_event_kwargs,
Expand Down Expand Up @@ -371,9 +384,10 @@ def deserialize(self, inputs: MetricaJsonEventDataInputs) -> EventDataset:
generic_event_kwargs[
"coordinates"
] = _parse_coordinates(raw_event["end"])
generic_event_kwargs["timestamp"] = raw_event["end"][
"time"
]
generic_event_kwargs["timestamp"] = (
timedelta(seconds=raw_event["end"]["time"])
- period.start_timestamp
)

event = self.event_factory.build_ball_out(
result=None,
Expand Down
8 changes: 3 additions & 5 deletions kloppy/infra/serializers/event/opta/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,16 +245,14 @@
}


def _parse_f24_datetime(dt_str: str) -> float:
def _parse_f24_datetime(dt_str: str) -> datetime:
def zero_pad_milliseconds(timestamp):
parts = timestamp.split(".")
return ".".join(parts[:-1] + ["{:03d}".format(int(parts[-1]))])

dt_str = zero_pad_milliseconds(dt_str)
return (
datetime.strptime(dt_str, "%Y-%m-%dT%H:%M:%S.%f")
.replace(tzinfo=pytz.utc)
.timestamp()
return datetime.strptime(dt_str, "%Y-%m-%dT%H:%M:%S.%f").replace(
tzinfo=pytz.utc
)


Expand Down
Loading

0 comments on commit e678d62

Please sign in to comment.