feat (TF backend)(module.py): adding a new TorchModuleHelpers class to expose `nn.Module` specific methods to the subclassed keras.Model and keras.Layer classes.
YushaArif99 committed Oct 11, 2024
1 parent 05e12c8 commit aa9848a
Showing 1 changed file with 322 additions and 2 deletions:
ivy/functional/backends/tensorflow/module.py
@@ -12,11 +12,17 @@
Any,
Tuple,
List,
Set,
Dict,
Type,
Iterator,
Optional,
Union,
TYPE_CHECKING,
)
import itertools
import warnings
import typing
import inspect
from collections import OrderedDict
from packaging.version import parse
@@ -280,7 +286,321 @@ def recursive_deserialize(d):
else:
return deserialize_obj(d)

class TorchModuleHelpers:
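    """Mixin exposing `torch.nn.Module`-style helpers (``add_module``,
    ``named_parameters``, ``load_state_dict``, ...) on the subclassed
    ``keras.Model`` and ``keras.layers.Layer`` classes defined below.
    """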

def add_module(self, name: str, module: Optional["Model"]) -> None:
if not isinstance(module, (Model, Layer, keras.Model, keras.layers.Layer)) and module is not None:
raise TypeError(f"{type(module)} is not a Module subclass")
elif not isinstance(name, str):
raise TypeError(f"module name should be a string. Got {type(name)}")
elif hasattr(self, name) and name not in self._modules:
raise KeyError(f"attribute '{name}' already exists")
elif "." in name:
raise KeyError(f'module name can\'t contain ".", got: {name}')
elif name == "":
raise KeyError('module name can\'t be empty string ""')

self._modules[name] = module

super().__setattr__(name, module)
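
    # Usage sketch (the subclass and layer below are hypothetical, not part
    # of this diff): registers the child under `self._modules` and exposes it
    # as an attribute, mirroring `torch.nn.Module.add_module`.
    # >>> class MyModel(Model):
    # ...     pass
    # >>> m = MyModel()
    # >>> m.add_module("proj", keras.layers.Dense(8))
    # >>> m.proj is m._modules["proj"]
    # True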

def apply(self, fn: Callable[["Model"], None]):
for module in self.children():
if hasattr(module, "apply"):
module.apply(fn)
else:
fn(module)
fn(self)
return self
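
    # Sketch mirroring `torch.nn.Module.apply` (the init function is
    # hypothetical): apply a function to every submodule, then to `self`.
    # >>> def zero_dense(m):
    # ...     if isinstance(m, keras.layers.Dense) and m.built:
    # ...         m.kernel.assign(tf.zeros_like(m.kernel))
    # >>> model = model.apply(zero_dense)  # returns self for chaining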

def _apply(self, fn, recurse=True):
if recurse:
if hasattr(self, "children"):
for module in self.children():
if hasattr(module, "_apply"):
module._apply(fn)
for key, param in self.v.items():
if param is not None:
self.v[key] = fn(param)
for key, buf in self.buffers.items():
if buf is not None:
self.buffers[key] = fn(buf)
return self

def _named_members(
self, get_members_fn, prefix="", recurse=True, remove_duplicate: bool = True
):
r"""Helper method for yielding various names + members of modules."""
memo = set()
modules = (
self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate)
if recurse
else [(prefix, self)]
)
for module_prefix, module in modules:
members = get_members_fn(module)
for k, v in members:
if v is None or id(v) in memo:
continue
if remove_duplicate:
memo.add(id(v))
name = module_prefix + ("." if module_prefix else "") + k
yield name, v
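
    # Note: `get_members_fn` selects which per-module mapping to walk (e.g.
    # `module.v.items()` for parameters, `module.buffers.items()` for
    # buffers); `memo` de-duplicates shared tensors by `id()` when
    # `remove_duplicate` is True.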

def register_module(self, name: str, module: Optional["Model"]) -> None:
r"""Alias for :func:`add_module`."""
self.add_module(name, module)

def get_submodule(self, target: str) -> "Model":
if target == "":
return self

atoms: List[str] = target.split(".")
mod: Model = self

for item in atoms:
if not hasattr(mod, item):
raise AttributeError(
mod._get_name() + " has no attribute `" + item + "`"
)

mod = getattr(mod, item)

if not isinstance(mod, (Model, Layer, keras.Model, keras.layers.Layer)):
raise TypeError("`" + item + "` is not a Module")

return mod
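
    # Usage sketch (attribute names are hypothetical): dotted targets walk
    # nested submodules, as in `torch.nn.Module.get_submodule`.
    # >>> attn = model.get_submodule("encoder.attention")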

def get_parameter(self, target: str):
target = target.replace(".", "/")
return self.v[target]
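
    # Note: the backend stores variables in `self.v` keyed by "/"-separated
    # paths, so a torch-style dotted target is translated first. Sketch with
    # hypothetical names:
    # >>> w = model.get_parameter("encoder.dense.w")  # reads v["encoder/dense/w"]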

def parameters(self, recurse: bool = True):
for _, param in self.named_parameters(recurse=recurse):
yield param

def named_parameters(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
):
if not getattr(self, "_built", False):
self.build(
*self._args, dynamic_backend=self._dynamic_backend, **self._kwargs
)
gen = self._named_members(
lambda module: module.v.items(),
prefix=prefix,
recurse=recurse,
remove_duplicate=remove_duplicate,
)
yield from gen
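
    # Usage sketch, analogous to `torch.nn.Module.named_parameters`; the lazy
    # `build` above ensures `self.v` is populated before anything is yielded.
    # >>> for name, p in model.named_parameters():
    # ...     print(name, p.shape)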

def named_buffers(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
):
if not getattr(self, "_built", False):
self.build(
*self._args, dynamic_backend=self._dynamic_backend, **self._kwargs
)
gen = self._named_members(
lambda module: module.buffers.items(),
prefix=prefix,
recurse=recurse,
remove_duplicate=remove_duplicate,
)
yield from gen

def children(self) -> Iterator["Model"]:
for _, module in self.named_children():
yield module

def named_children(self) -> Iterator[Tuple[str, "Model"]]:
if not getattr(self, "_built", False):
self.build(
*self._args, dynamic_backend=self._dynamic_backend, **self._kwargs
)
memo = set()
for name, module in self._module_dict.items():
if module is not None and id(module) not in memo:
memo.add(id(module))
yield name, module

def modules(self) -> Iterator["Model"]:
for _, module in self.named_modules():
yield module

def named_modules(
self,
memo: Optional[Set["Model"]] = None,
prefix: str = "",
remove_duplicate: bool = True,
):
if not getattr(self, "_built", False):
self.build(
*self._args, dynamic_backend=self._dynamic_backend, **self._kwargs
)
if memo is None:
memo = set()
if id(self) not in memo:
if remove_duplicate:
memo.add(id(self))
yield prefix, self
for name, module in self._module_dict.items():
if module is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
            if not hasattr(module, "named_modules"):
                yield submodule_prefix, module
else:
yield from module.named_modules(
memo, submodule_prefix, remove_duplicate
)
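
    # Usage sketch: yields (prefix, module) pairs depth-first, starting with
    # the root under the empty prefix, as `torch.nn.Module.named_modules` does.
    # >>> list(dict(model.named_modules()).keys())
    # ['', 'encoder', 'encoder.attention', ...]  # hypothetical names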

def _load_from_state_dict(
self, state_dict, prefix, strict, missing_keys, unexpected_keys, error_msgs
):
        def _retrieve_layer(model, key):
if len(key.split(".")) == 1:
return model, key

module_path, weight_name = key.rsplit(".", 1)

# Retrieve the layer using the module path
layer = model
for attr in module_path.split("."):
layer = getattr(layer, attr)

return layer, weight_name

        persistent_buffers = dict(self._buffers)
local_name_params = itertools.chain(
self._parameters.items(), persistent_buffers.items()
)
local_state = {k: v for k, v in local_name_params if v is not None}

for name, param in local_state.items():
key = prefix + name
if key in state_dict:
input_param = state_dict[key]
if not isinstance(input_param, tf.Tensor):
error_msgs.append(
f'While copying the parameter named "{key}", '
"expected ArrayLike object from checkpoint but "
f"received {type(input_param)}"
)
continue

if not isinstance(input_param, KerasVariable):
input_param = KerasVariable(input_param)

                layer, weight_name = _retrieve_layer(self, name)
try:
setattr(layer, weight_name, input_param)
except Exception as ex:
error_msgs.append(
f'While copying the parameter named "{key}", '
f"whose dimensions in the model are {param.shape} and "
f"whose dimensions in the checkpoint are {input_param.shape}, "
f"an exception occurred : {ex.args}."
)
elif strict:
missing_keys.append(key)

if strict:
for key in state_dict.keys():
if key.startswith(prefix):
input_name = key[len(prefix) :].split(".", 1)
if len(input_name) > 1:
if input_name[0] not in self._modules:
unexpected_keys.append(key)
elif input_name[0] not in local_state:
unexpected_keys.append(key)
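
        # With `strict=True`, any checkpoint key under `prefix` that resolves
        # to neither a registered submodule nor a local parameter/buffer is
        # recorded in `unexpected_keys`; parameters absent from the checkpoint
        # were recorded in `missing_keys` above.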

def load_state_dict(
self,
state_dict: typing.Mapping[str, Any],
strict: bool = True,
assign: bool = False,
):
r"""Copy parameters and buffers from :attr:`state_dict` into this module and its descendants.
If :attr:`strict` is ``True``, then
the keys of :attr:`state_dict` must exactly match the keys returned
by this module's :meth:`~Module.state_dict` function.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~Module.state_dict` function. Default: ``True``
Returns:
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
* **missing_keys** is a list of str containing any keys that are expected
by this module but missing from the provided ``state_dict``.
* **unexpected_keys** is a list of str containing the keys that are not
expected by this module but present in the provided ``state_dict``.
"""
if not isinstance(state_dict, typing.Mapping):
raise TypeError(
f"Expected state_dict to be dict-like, got {type(state_dict)}."
)

missing_keys: List[str] = []
unexpected_keys: List[str] = []
error_msgs: List[str] = []

state_dict = tf.nest.map_structure(
lambda x: tf.convert_to_tensor(x.numpy()),
state_dict,
)
state_dict = OrderedDict(state_dict)

def load(module, local_state_dict, prefix=""):
module._load_from_state_dict(
local_state_dict,
prefix,
strict,
missing_keys,
unexpected_keys,
error_msgs,
)
# TODO: maybe we should implement this similar to PT
# and make this recursive.

load(self, state_dict)
del load

if len(error_msgs) > 0:
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(
self.__class__.__name__, "\n\t".join(error_msgs)
)
)
if strict:
missing_keys = sorted(missing_keys)
unexpected_keys = sorted(unexpected_keys)
if len(missing_keys) > 0:
warnings.warn(
"Missing key(s) in state_dict: {}\n".format(
", ".join(f"'{k}'" for k in missing_keys)
)
)
if len(unexpected_keys) > 0:
warnings.warn(
"Unexpected key(s) in state_dict: {}\n".format(
", ".join(f"'{k}'" for k in unexpected_keys)
)
)
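
    # Usage sketch (assumes checkpoint values expose `.numpy()`, e.g. torch
    # or TF tensors, and keys that match this module's parameter names):
    # >>> sd = torch_model.state_dict()             # hypothetical source
    # >>> tf_model.load_state_dict(sd, strict=False)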

def requires_grad_(self, requires_grad: bool = True):
for p in self.parameters():
p.requires_grad_(requires_grad)
return self

def _get_name(self):
return self.__class__.__name__
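
# A minimal sketch of how the mixin is consumed (subclass and layer names are
# hypothetical): the `Layer` and `Model` classes defined below inherit these
# torch-style helpers alongside the usual Keras API.
# >>> class MyModel(Model):
# ...     def __init__(self):
# ...         super().__init__()
# ...         self.add_module("dense", keras.layers.Dense(4))
# >>> list(MyModel().named_children())  # torch-style introspection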

class ModelHelpers:
@staticmethod
@tf.autograph.experimental.do_not_convert
@@ -446,7 +766,7 @@ def _addindent(s_, numSpaces):
return s


-class Layer(tf.keras.layers.Layer, ModelHelpers):
+class Layer(tf.keras.layers.Layer, ModelHelpers, TorchModuleHelpers):
_build_mode = None
_with_partial_v = None
_store_vars = True
@@ -1138,7 +1458,7 @@ def __repr__(self):
return main_str


-class Model(tf.keras.Model, ModelHelpers):
+class Model(tf.keras.Model, ModelHelpers, TorchModuleHelpers):
_build_mode = None
_with_partial_v = None
_store_vars = True
