Skip to content

Commit

Permalink
PyTorch Train: add tests for using arg indexes for model inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Jun 14, 2024
1 parent 89a0354 commit 1d4bf9a
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions tests/cases/torch_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@ def forward(self, a, b):
),
],
)
def test_loss_drops(tmpdir, device):
@pytest.mark.parametrize("input_args", [True, False])
def test_loss_drops(tmpdir, device, input_args):
checkpoint_basename = str(tmpdir / "model")

a_key = ArrayKey("A")
Expand All @@ -104,7 +105,7 @@ def test_loss_drops(tmpdir, device):
model=model,
optimizer=optimizer,
loss=loss,
inputs={"a": a_key, "b": b_key},
inputs={"a": a_key, "b": b_key} if not input_args else {0: a_key, 1: b_key},
loss_inputs={0: c_predicted_key, 1: c_key},
outputs={0: c_predicted_key},
gradients={0: c_gradient_key},
Expand Down Expand Up @@ -167,7 +168,8 @@ def test_loss_drops(tmpdir, device):
),
],
)
def test_output(device):
@pytest.mark.parametrize("input_args", [True, False])
def test_spawn_subprocess(device, input_args):
logging.getLogger("gunpowder.torch.nodes.predict").setLevel(logging.INFO)

a_key = ArrayKey("A")
Expand All @@ -181,7 +183,7 @@ def test_output(device):
source = example_train_source(a_key, b_key, c_key)
predict = Predict(
model=model,
inputs={"a": a_key, "b": b_key},
inputs={"a": a_key, "b": b_key} if not input_args else {0: a_key, 1: b_key},
outputs={"linear": c_pred, 0: d_pred},
array_specs={
c_key: ArraySpec(nonspatial=True),
Expand Down

0 comments on commit 1d4bf9a

Please sign in to comment.