Skip to content

Commit

Permalink
fix test_get_model_state_dict_ignore
Browse files Browse the repository at this point in the history
  • Loading branch information
ez2rok committed Sep 24, 2024
1 parent 23ddbfb commit e07dc08
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 8 deletions.
1 change: 0 additions & 1 deletion composer/models/tasks/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def get_metrics(self, is_train: bool = False) -> dict[str, Metric]:
metrics_dict = {metrics.__class__.__name__: metrics}
else:
metrics_dict = {}
assert metrics is not None
for name, metric in metrics.items():
assert isinstance(metric, Metric)
metrics_dict[name] = metric
Expand Down
4 changes: 3 additions & 1 deletion tests/algorithms/test_algorithms_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from composer import Algorithm, Trainer
from composer.algorithms import GyroDropout, LayerFreezing
from tests.algorithms.algorithm_settings import get_alg_dataloader, get_alg_kwargs, get_alg_model, get_algs_with_marks

from icecream import install
install()

@pytest.mark.gpu
@pytest.mark.parametrize('alg_cls', get_algs_with_marks())
Expand All @@ -15,6 +16,7 @@ def test_algorithm_trains(alg_cls: type[Algorithm]):
alg_kwargs = get_alg_kwargs(alg_cls)
model = get_alg_model(alg_cls)
dataloader = get_alg_dataloader(alg_cls)
ic(model, dataloader)
trainer = Trainer(
model=model,
train_dataloader=dataloader,
Expand Down
7 changes: 4 additions & 3 deletions tests/checkpoint/test_state_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from tests.common.compare import deep_compare
from tests.common.markers import world_size
from tests.common.models import EvenSimplerMLP, SimpleComposerMLP, configure_tiny_gpt2_hf_model

from icecream import install
install()

@pytest.mark.gpu
@pytest.mark.parametrize('use_composer_model', [True, False])
Expand Down Expand Up @@ -60,10 +61,10 @@ def test_get_model_state_dict_ignore(use_composer_model: bool):
model = EvenSimplerMLP(num_features=8, device='cuda')

model_state_dict = get_model_state_dict(model, sharded_state_dict=False, ignore_keys='module.2.weight')
assert set(model_state_dict.keys()) == {'module.0.weight'}
assert set(model_state_dict.keys()) == {'fc1.weight', 'fc2.weight', 'module.0.weight'}

model_state_dict = get_model_state_dict(model, sharded_state_dict=False, ignore_keys=['module.2*'])
assert set(model_state_dict.keys()) == {'module.0.weight'}
assert set(model_state_dict.keys()) == {'fc1.weight', 'fc2.weight', 'module.0.weight'}


@pytest.mark.gpu
Expand Down
7 changes: 4 additions & 3 deletions tests/common/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,10 @@ class EvenSimplerMLP(torch.nn.Module):

def __init__(self, num_features: int, device: str = 'cpu', num_out_features: int = 3):
super().__init__()
fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False)
fc2 = torch.nn.Linear(num_features, num_out_features, device=device, bias=False)
self.fc1 = torch.nn.Linear(num_features, num_features, device=device, bias=False)
self.fc2 = torch.nn.Linear(num_features, num_out_features, device=device, bias=False)

self.module = torch.nn.Sequential(fc1, torch.nn.ReLU(), fc2)
self.module = torch.nn.Sequential(self.fc1, torch.nn.ReLU(), self.fc2)

def forward(self, x):
return self.module(x)
Expand Down Expand Up @@ -480,6 +480,7 @@ def loss(self, outputs: torch.Tensor, batch: tuple[Any, torch.Tensor], *args, **

def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> None:
_, targets = batch
ic(metric, outputs, targets)
metric.update(outputs.squeeze(dim=0), targets.squeeze(dim=0))

def forward(self, batch: tuple[torch.Tensor, Any]) -> torch.Tensor:
Expand Down

0 comments on commit e07dc08

Please sign in to comment.