diff --git a/README.md b/README.md index 0557c984..12163d93 100644 --- a/README.md +++ b/README.md @@ -567,6 +567,17 @@ ```py mm.save_weights("foo.h5") ``` + - **Training with compile and fit** + ```py + import torch + if torch.cuda.is_available(): + _ = mm.to("cuda") + xx = torch.rand([64, *mm.input_shape[1:]]) + yy = torch.functional.F.one_hot(torch.randint(0, mm.output_shape[-1], size=[64]), mm.output_shape[-1]).float() + loss = lambda y_pred, y_true: (y_true - y_pred.float()).abs().mean() + mm.compile(optimizer="AdamW", loss=loss, metrics='acc', grad_accumulate=4) + mm.fit(xx, yy, epochs=2, batch_size=4) + ``` *** # Recognition Models