Skip to content

Commit

Permalink
Fix loaders for M5 & ETT datasets (#3155)
Browse files: browse the repository at this point in the history
*Description of changes:*
- Fix how `item_id` is obtained for M5 and ETT datasets
- Fix `lxml` dependency range


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.


**Please tag this PR with at least one of these labels to make our
release process faster:** BREAKING, new feature, bug fix, other change,
dev setup
  • Loading branch information
shchur authored Apr 2, 2024
1 parent e74bbc5 commit 739627a
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 7 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ filterwarnings = "ignore"
[tool.ruff]
line-length = 79

-ignore = [
+lint.ignore = [
# line-length is handled by black
"E501",

Expand Down
1 change: 1 addition & 0 deletions requirements/requirements-docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ ipykernel~=6.5
nbconvert~=6.5.1
nbsphinx~=0.8.8
notedown
+lxml~=5.1.0
pytest-runner~=2.11
recommonmark
sphinx~=4.0
Expand Down
8 changes: 4 additions & 4 deletions src/gluonts/dataset/repository/_ett_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,29 +39,29 @@ def generate_ett_small_dataset(
dfs.append(df)

test = []
-for df in dfs:
+for region, df in enumerate(dfs):
start = pd.Period(df["date"][0], freq=freq)
for col in df.columns:
if col in ["date"]:
continue
test.append(
{
"start": start,
-"item_id": col,
+"item_id": f"{col}_{region}",
"target": df[col].values,
}
)

train = []
-for df in dfs:
+for region, df in enumerate(dfs):
start = pd.Period(df["date"][0], freq=freq)
for col in df.columns:
if col in ["date"]:
continue
train.append(
{
"start": start,
-"item_id": col,
+"item_id": f"{col}_{region}",
"target": df[col].values[:-prediction_length],
}
)
Expand Down
6 changes: 5 additions & 1 deletion src/gluonts/dataset/repository/_m5.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,11 @@ def generate_m5_dataset(
]

# Build target series
-train_ids = sales_train_validation["item_id"]
+train_ids = (
+sales_train_validation["item_id"].str
++ "_"
++ sales_train_validation["store_id"].str
+)
train_df = sales_train_validation.drop(
["id", "item_id", "dept_id", "cat_id", "store_id", "state_id"],
axis=1,
Expand Down
3 changes: 2 additions & 1 deletion src/gluonts/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
character if set to `True`.
"""

-__all__ = [ # noqa
+# ruff: noqa: F822
+__all__ = [
"variant",
"dump",
"dumps",
Expand Down

0 comments on commit 739627a

Please sign in to comment.