Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
mivanit committed Aug 22, 2024
1 parent 98a2083 commit 85ee878
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 13 deletions.
10 changes: 0 additions & 10 deletions maze_transformer/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,25 +589,17 @@ def get_config_multisource(
dataset_cfg_name, model_cfg_name, train_cfg_name = cfg_names
# assemble the collective name
name = f"multsrc_{dataset_cfg_name}_{model_cfg_name}_{train_cfg_name}"
print(f"gcm 591: {dataset_cfg_name = }")
else:
# 4 names if collective name, unpack it
dataset_cfg_name, model_cfg_name, train_cfg_name, name = cfg_names
print(f"gcm 595: {dataset_cfg_name = }")

try:
# try to actually assemble the configuration by looking up names in dicts
print(f"gcm 599: {dataset_cfg_name = }")
for k, v in MAZE_DATASET_CONFIGS.items():
print(f"{k}: {v.summary()}")

config = ConfigHolder(
name=name,
dataset_cfg=copy.deepcopy(MAZE_DATASET_CONFIGS[dataset_cfg_name]),
model_cfg=copy.deepcopy(GPT_CONFIGS[model_cfg_name]),
train_cfg=copy.deepcopy(TRAINING_CONFIGS[train_cfg_name]),
)
print(f"gcm 612: {config.dataset_cfg.summary() = }")
except KeyError as e:
# exception handling for missing keys case
raise KeyError(
Expand All @@ -620,14 +612,12 @@ def get_config_multisource(
raise ValueError(
"Must provide exactly one of cfg, cfg_file, or cfg_names. this state should be unreachable btw."
)
print(f"gcm 604: {config.dataset_cfg.summary() = }")
# update config with kwargs
if kwargs_in:
kwargs_dict: dict = kwargs_to_nested_dict(
kwargs_in, sep=".", strip_prefix="cfg.", when_unknown_prefix="raise"
)
config.update_from_nested_dict(kwargs_dict)
print(f"gcm 611: {config.dataset_cfg.summary() = }")
return config


Expand Down
3 changes: 0 additions & 3 deletions maze_transformer/training/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def train_model(
- model config names: {model_cfg_names}
- train config names: {train_cfg_names}
"""
print(cfg.dataset_cfg.summary())
if help:
print(train_model.__doc__)
return
Expand Down Expand Up @@ -110,7 +109,6 @@ def train_model(
logger.progress("Summary logged, getting dataset")

# load dataset
print(cfg.dataset_cfg.summary())
if dataset is None:
dataset = MazeDataset.from_config(
cfg=cfg.dataset_cfg,
Expand Down Expand Up @@ -151,7 +149,6 @@ def train_model(
)

logger.progress(f"finished getting training dataset with {len(dataset)} samples")
print(f"{len(dataset) = }")
# validation dataset, if applicable
val_dataset: MazeDataset | None = None
if cfg.train_cfg.validation_dataset_cfg is not None:
Expand Down

0 comments on commit 85ee878

Please sign in to comment.