From 042e8a5b1ca77fa9507af3e1f47a21f72b45b24e Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Thu, 26 Sep 2024 10:23:09 -0700 Subject: [PATCH] Schema test for optimizer classes (#2429) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2429 BC API tests for optimizer classes for stable release Reviewed By: iamzainhuda Differential Revision: D63438376 fbshipit-source-id: a34dbca7dbc949a13d0052ed3765591c2cb2104e --- .../schema/api_tests/test_optimizer_schema.py | 113 ++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 torchrec/schema/api_tests/test_optimizer_schema.py diff --git a/torchrec/schema/api_tests/test_optimizer_schema.py b/torchrec/schema/api_tests/test_optimizer_schema.py new file mode 100644 index 000000000..a204c67e7 --- /dev/null +++ b/torchrec/schema/api_tests/test_optimizer_schema.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import inspect +import unittest +from typing import Any, Collection, List, Mapping, Optional, Set, Tuple, Union + +import torch +from torch import optim + +from torchrec.distributed.types import ShardedTensor +from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer +from torchrec.schema.utils import is_signature_compatible + + +class StableKeyedOptimizer(optim.Optimizer): + def __init__( + self, + params: Mapping[str, Union[torch.Tensor, ShardedTensor]], + # pyre-ignore [2] + state: Mapping[Any, Any], + param_groups: Collection[Mapping[str, Any]], + ) -> None: + pass + + def init_state( + self, + sparse_grad_parameter_names: Optional[Set[str]] = None, + ) -> None: + pass + + def save_param_groups(self, save: bool) -> None: + pass + + # pyre-ignore [2] + def add_param_group(self, param_group: Any) -> None: + pass + + def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: + pass + + +class StableCombinedOptimizer(KeyedOptimizer): + def __init__( + self, + optims: List[Union[KeyedOptimizer, Tuple[str, KeyedOptimizer]]], + ) -> None: + pass + + @property + def optimizers(self) -> List[Tuple[str, StableKeyedOptimizer]]: + return [] + + @staticmethod + def prepend_opt_key(name: str, opt_key: str) -> str: + return "" + + @property + def param_groups(self) -> Collection[Mapping[str, Any]]: + return [] + + @property + def params(self) -> Mapping[str, Union[torch.Tensor, ShardedTensor]]: + return {} + + def post_load_state_dict(self) -> None: + pass + + def save_param_groups(self, save: bool) -> None: + pass + + # pyre-ignore [2] + def step(self, closure: Any = None) -> None: + pass + + def zero_grad(self, set_to_none: bool = False) -> None: + pass + + +class TestOptimizerSchema(unittest.TestCase): + def test_keyed_optimizer(self) -> None: + stable_keyed_optimizer_funcs = inspect.getmembers( + StableKeyedOptimizer, predicate=inspect.isfunction + ) + + for func_name, stable_func in stable_keyed_optimizer_funcs: + self.assertTrue(getattr(KeyedOptimizer, func_name, None) is not None) + self.assertTrue( + is_signature_compatible( + inspect.signature(stable_func), + inspect.signature(getattr(KeyedOptimizer, func_name)), + ) + ) + + def test_combined_optimizer(self) -> None: + stable_combined_optimizer_funcs = inspect.getmembers( + StableCombinedOptimizer, predicate=inspect.isfunction + ) + + for func_name, stable_func in stable_combined_optimizer_funcs: + self.assertTrue(getattr(CombinedOptimizer, func_name, None) is not None) + self.assertTrue( + is_signature_compatible( + inspect.signature(stable_func), + inspect.signature(getattr(CombinedOptimizer, func_name)), + ) + )