Skip to content

Commit

Permalink
fix type checking
Browse files Browse the repository at this point in the history
  • Loading branch information
ez2rok committed Sep 25, 2024
1 parent ae5980a commit ff36f17
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
4 changes: 2 additions & 2 deletions llmfoundry/utils/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,8 +504,8 @@ def update_batch_size_info(cfg: dict[str, Any]) -> dict[str, Any]:

def process_init_device(
model_cfg: dict[str, Any],
fsdp_config: Optional[dict],
tp_config: Optional[dict],
fsdp_config: Optional[dict] = None,
tp_config: Optional[dict] = None,
):
# Restrict model init_device to 'meta' and 'cpu',
# using 'cuda' vs. 'cuda:id' is tricky and can lead to common user errors
Expand Down
16 changes: 13 additions & 3 deletions tests/models/utils/test_tp_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,22 @@ def test_ffn_tp_strategy_layer_plan():
):
assert n1 == n2
assert type(lp1) == type(lp2)
if isinstance(lp1, PrepareModuleInput):
if isinstance(
lp1,
PrepareModuleInput,
) and isinstance(lp2, PrepareModuleInput):
assert lp1.input_layouts == lp2.input_layouts
assert lp1.desired_input_layouts == lp2.desired_input_layouts
assert lp1.use_local_output == lp2.use_local_output
elif isinstance(lp1,
ColwiseParallel) or isinstance(lp1, RowwiseParallel):
elif (
isinstance(lp1, ColwiseParallel) and
isinstance(lp2, ColwiseParallel)
) or (
isinstance(lp1, RowwiseParallel) and
isinstance(lp2, RowwiseParallel)
):
assert lp1.input_layouts == lp2.input_layouts
assert lp1.output_layouts == lp2.output_layouts
assert lp1.use_local_output == lp2.use_local_output
else:
raise ValueError(f'Layer plan of wrong type: {type(layer_plan)}')

0 comments on commit ff36f17

Please sign in to comment.