diff --git a/morpho_symm/nn/EMLP.py b/morpho_symm/nn/EMLP.py index 9d7782b..7fc3aab 100644 --- a/morpho_symm/nn/EMLP.py +++ b/morpho_symm/nn/EMLP.py @@ -1,11 +1,13 @@ import logging -from typing import List, Union +from typing import Union import escnn import numpy as np import torch from escnn.nn import EquivariantModule, FieldType +from morpho_symm.nn.EquivariantModules import IsotypicBasis + log = logging.getLogger(__name__) @@ -33,7 +35,13 @@ def __init__(self, Args: ---- in_type (escnn.nn.FieldType): Input field type containing the representation of the input space. - out_type (escnn.nn.FieldType): Output field type containing the representation of the output space. + out_type (escnn.nn.FieldType): Output field type containing the representation of the output space. If the + output type representation is composed only of multiples of the trivial representation, then this model + will default to a G-invariant function. This G-invariant function is composed of a G-equivariant feature + extractor, a G-invariant pooling layer extracting invariant features from the equivariant features, and a + final (unconstrained) linear layer mapping the invariant features to the output space. + TODO: Create a G-invariant EMLP class where we can control the network processing the G-invariant + features instead of defaulting to a single linear layer. num_hidden_units: Number of hidden units in the intermediate layers. The effective number of hidden units will be ceil(num_hidden_units/|G|). Since we assume intermediate embeddings are regular fields. activation (escnn.nn.EquivariantModule, str): Name of pointwise activation function to use. @@ -51,8 +59,21 @@ def __init__(self, self.in_type, self.out_type = in_type, out_type self.gspace = self.in_type.gspace self.group = self.gspace.fibergroup - self.num_layers = num_layers + + # Check if the network is a G-invariant function (i.e., out rep is composed only of the trivial representation) + out_irreps = set(out_type.representation.irreps) + if len(out_irreps) == 1 and self.group.trivial_representation.id == list(out_irreps)[0]: + self.invariant_fn = True + else: + self.invariant_fn = False + input_irreps = set(in_type.representation.irreps) + inner_irreps = set(out_type.irreps) + diff = input_irreps.symmetric_difference(inner_irreps) + if len(diff) > 0: + log.warning(f"Irreps {list(diff)} of group {self.gspace.fibergroup} are not in the input/output types. " + f"This represents an information bottleneck. Consider extracting invariant features.") + if self.num_layers == 1 and not head_with_activation: log.warning(f"{self} model with 1 layer and no activation. This is equivalent to a linear map") @@ -66,15 +87,7 @@ def __init__(self, else: raise ValueError(f"Activation type {type(activation)} not supported.") - input_irreps = set(in_type.representation.irreps) - inner_irreps = set(out_type.irreps) - diff = input_irreps.symmetric_difference(inner_irreps) - if len(diff) > 0: - log.warning(f"Irreps {list(diff)} of group {self.gspace.fibergroup} are not in the input/output types." - f"This represents an information bottleneck. 
Consider extracting invariant features.") - layer_in_type = in_type - self.net = escnn.nn.SequentialModule() for n in range(self.num_layers - 1): layer_out_type = hidden_type @@ -89,22 +102,53 @@ def __init__(self, layer_in_type = layer_out_type # Add final layer - head_block = escnn.nn.SequentialModule() - head_block.add_module(f"linear_{num_layers - 1}", escnn.nn.Linear(layer_in_type, out_type, bias=bias)) - if head_with_activation: - if batch_norm: - head_block.add_module(f"batchnorm_{num_layers - 1}", escnn.nn.IIDBatchNorm1d(out_type)), - head_block.add_module(f"act_{num_layers - 1}", activation) - # head_layer.check_equivariance() - self.net.add_module("head", head_block) + self.net_head = None + if self.invariant_fn: + self.net_head = torch.nn.Sequential() + # TODO: Make the G-invariant pooling with Isotypic Basis a stand alone module. + # Module describing the change of basis to an Isotypic Basis required for efficient G-invariant pooling + self.change2isotypic_basis = IsotypicBasis(hidden_type) + # Number of G-invariant features from net output equals the number of G-stable subspaces. + num_inv_features = len(hidden_type.irreps) + self.net_head.add_module(f"linear_{num_layers - 1}", + torch.nn.Linear(num_inv_features, out_type.size, bias=bias)) + if head_with_activation: + if batch_norm: + self.net_head.add_module(f"batchnorm_{num_layers - 1}", torch.nn.BatchNorm1d(out_type.size)), + self.net_head.add_module(f"act_{num_layers - 1}", activation) + else: # Equivariant Network + self.net_head = escnn.nn.SequentialModule() + self.net_head.add_module(f"linear_{num_layers - 1}", escnn.nn.Linear(layer_in_type, out_type, bias=bias)) + if head_with_activation: + if batch_norm: + self.net_head.add_module(f"batchnorm_{num_layers - 1}", escnn.nn.IIDBatchNorm1d(out_type)), + self.net_head.add_module(f"act_{num_layers - 1}", activation) # Test the entire model is equivariant. 
# self.net.check_equivariance() + def forward(self, x): + """Forward pass of the EMLP model.""" + equivariant_features = self.net(x) + if self.invariant_fn: + iso_equivariant_features = self.change2isotypic_basis(equivariant_features) + invariant_features = self.irrep_norm_pooling(iso_equivariant_features.tensor, iso_equivariant_features.type) + output = self.net_head(invariant_features) + output = self.out_type(output) # Wrap over invariant field type + else: + output = self.net_head(equivariant_features) + return output + + def reset_parameters(self, init_mode=None): + """Initialize weights and biases of E-MLP model.""" + raise NotImplementedError() + @staticmethod def get_activation(activation, in_type: FieldType, desired_hidden_units: int): gspace = in_type.gspace group = gspace.fibergroup grid_length = group.order() if not group.continuous else 20 + if isinstance(group, escnn.group.DihedralGroup): + grid_length = grid_length // 2 unique_irreps = set(in_type.irreps) unique_irreps_dim = sum([group.irrep(*id).size for id in set(in_type.irreps)]) @@ -122,13 +166,31 @@ def get_activation(activation, in_type: FieldType, desired_hidden_units: int): type='regular' if not group.continuous else 'rand', N=grid_length) - def forward(self, x): - """Forward pass of the EMLP model.""" - return self.net(x) - - def reset_parameters(self, init_mode=None): - """Initialize weights and biases of E-MLP model.""" - raise NotImplementedError() + @staticmethod + def irrep_norm_pooling(x: torch.Tensor, field_type: FieldType) -> torch.Tensor: + n_inv_features = len(field_type.irreps) + # TODO: Ensure isotypic basis i.e irreps of the same type are consecutive to each other. + inv_features = [] + for field_start, field_end, rep in zip(field_type.fields_start, + field_type.fields_end, + field_type.representations): + # Each field here represents a representation of an Isotypic Subspace. This rep is only composed of a single + # irrep type. + x_field = x[..., field_start:field_end] + num_G_stable_spaces = len(rep.irreps) # Number of G-invariant features = multiplicity of irrep + # Again this assumes we are already in an Isotypic basis + assert len(np.unique(rep.irreps, axis=0)) == 1, "This only works for now on the Isotypic Basis" + # This basis is useful because we can apply the norm in a vectorized way + # Reshape features to [batch, num_G_stable_spaces, num_features_per_G_stable_space] + x_field_p = torch.reshape(x_field, (x_field.shape[0], num_G_stable_spaces, -1)) + # Compute G-invariant measures as the norm of the features in each G-stable space + inv_field_features = torch.norm(x_field_p, dim=-1) + # Append to the list of inv features + inv_features.append(inv_field_features) + # Concatenate all the invariant features + inv_features = torch.cat(inv_features, dim=-1) + assert inv_features.shape[-1] == n_inv_features, f"Expected {n_inv_features} got {inv_features.shape[-1]}" + return inv_features def evaluate_output_shape(self, input_shape: tuple[int, ...]) -> tuple[int, ...]: """Returns the output shape of the model given an input shape.""" @@ -136,83 +198,43 @@ def evaluate_output_shape(self, input_shape: tuple[int, ...]) -> tuple[int, ...] return batch_size, self.out_type.size -class MLP(torch.nn.Module): - """Standard baseline MLP. 
Representations and group are used for shapes only.""" - - def __init__(self, d_in, d_out, num_hidden_units=128, num_layers=3, - activation: Union[torch.nn.Module, List[torch.nn.Module]] = torch.nn.ReLU, - with_bias=True, init_mode="fan_in"): - """Constructor of a Multi-Layer Perceptron (MLP) model. - - This utility class allows to easily instanciate a G-equivariant MLP architecture. - - Args: - d_in: Dimension of the input space. - d_out: Dimension of the output space. - num_hidden_units: Number of hidden units in the intermediate layers. - num_layers: Number of layers in the MLP including input and output/head layers. That is, the number of - activation (escnn.nn.EquivariantModule, list(escnn.nn.EquivariantModule)): If a single activation module is - provided it will be used for all layers except the output layer. If a list of activation modules is provided - then `num_layers` activation equivariant modules should be provided. - with_bias: Whether to include a bias term in the linear layers. - init_mode: Not used until now. Will be used to initialize the weights of the MLP - """ - super().__init__() - self.d_in = d_in - self.d_out = d_out - self.init_mode = init_mode - self.hidden_channels = num_hidden_units - self.activation = activation - - logging.info("Initializing MLP") - - dim_in = self.d_in - dim_out = num_hidden_units - self.net = torch.nn.Sequential() - for n in range(num_layers - 1): - dim_out = num_hidden_units - block = torch.nn.Sequential() - block.add_module(f"linear_{n}", torch.nn.Linear(dim_in, dim_out, bias=with_bias)) - block.add_module(f"batchnorm_{n}", torch.nn.BatchNorm1d(dim_out)) - block.add_module(f"act_{n}", activation()) - - self.net.add_module(f"block_{n}", block) - dim_in = dim_out - # Add last layer - linear_out = torch.nn.Linear(in_features=dim_out, out_features=self.d_out, bias=with_bias) - self.net.add_module("head", linear_out) - - self.reset_parameters(init_mode=self.init_mode) - - def forward(self, input): - output = self.net(input) - return output - - def get_hparams(self): - return {'num_layers': len(self.net), - 'hidden_ch': self.hidden_channels, - 'init_mode': self.init_mode} - - def reset_parameters(self, init_mode=None): - assert init_mode is not None or self.init_mode is not None - self.init_mode = self.init_mode if init_mode is None else init_mode - for module in self.net: - if isinstance(module, torch.nn.Sequential): - tensor = module[0].weight - activation = module[-1].__class__.__name__ - elif isinstance(module, torch.nn.Linear): - tensor = module.weight - activation = "Linear" - else: - raise NotImplementedError(module.__class__.__name__) - - if "fan_in" == self.init_mode or "fan_out" == self.init_mode: - torch.nn.init.kaiming_uniform_(tensor, mode=self.init_mode, nonlinearity=activation.lower()) - elif 'normal' in self.init_mode.lower(): - split = self.init_mode.split('l') - std = 0.1 if len(split) == 1 else float(split[1]) - torch.nn.init.normal_(tensor, 0, std) - else: - raise NotImplementedError(self.init_mode) - - log.info(f"MLP initialized with mode: {self.init_mode}") +if __name__ == "__main__": + G = escnn.group.DihedralGroup(6) + gspace = escnn.gspaces.no_base_space(G) + # Test Invariant EMLP + in_type = escnn.nn.FieldType(gspace, [G.regular_representation] * 5) + out_type = escnn.nn.FieldType(gspace, [G.trivial_representation] * 6) + emlp = EMLP(in_type, out_type, + num_hidden_units=128, + num_layers=3, + activation="ReLU", + head_with_activation=False) + emlp.eval() # Shut down batch norm + x = in_type(torch.randn(1, 
in_type.size)) + y = emlp(x) + + for g in G.elements: + g_x = in_type(in_type.transform_fibers(x.tensor, g)) # Compute g · x + g_y = emlp(g_x) # Compute g · y + assert torch.allclose(y.tensor, g_y.tensor, rtol=1e-4, atol=1e-4), \ + f"{g} invariance failed {y.tensor} != {g_y.tensor}" + + # Test Equivariant EMLP + in_type = escnn.nn.FieldType(gspace, [G.regular_representation] * 5) + out_type = escnn.nn.FieldType(gspace, [G.regular_representation] * 2) + emlp = EMLP(in_type, out_type, + num_hidden_units=128, + num_layers=3, + activation="ReLU", + head_with_activation=False) + emlp.eval() # Shut down batch norm + + x = in_type(torch.randn(1, in_type.size)) + y = emlp(x) + + for g in G.elements: + g_x = in_type(in_type.transform_fibers(x.tensor, g)) # Compute g · x + g_y_gt = out_type(out_type.transform_fibers(y.tensor, g)) # Compute ground truth g · y + g_y = emlp(g_x) # Compute g · y + assert torch.allclose(g_y_gt.tensor, g_y.tensor, rtol=1e-4, atol=1e-4), \ + f"{g} equivariance failed {g_y_gt.tensor} != {g_y.tensor}" diff --git a/morpho_symm/nn/EquivariantModules.py b/morpho_symm/nn/EquivariantModules.py index dd1b82e..bf185b1 100644 --- a/morpho_symm/nn/EquivariantModules.py +++ b/morpho_symm/nn/EquivariantModules.py @@ -1,395 +1,60 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# @Time : 31/1/22 -# @Author : Daniel Ordonez -# @email : daniels.ordonez@gmail.com -# Some code was adapted from https://github.com/ElisevanderPol/symmetrizer/blob/master/symmetrizer/nn/modules.py -import itertools -import logging -import math -import pathlib -from typing import Optional, Union -from zipfile import BadZipFile +from typing import Tuple +import escnn import numpy as np import torch +from escnn.nn import EquivariantModule, FieldType, GeometricTensor +from torch.nn import Parameter -# from emlp import Group -# from emlp.reps.representation import Rep -from emlp.reps.representation import Base as BaseRep -from scipy.sparse import issparse -from torch.nn import functional as F -from torch.nn.modules.utils import _single +from morpho_symm.utils.rep_theory_utils import isotypic_decomp_representation -from morpho_symm.groups.SemiDirectProduct import SemiDirectProduct -from morpho_symm.groups.SparseRepresentation import SparseRep -from morpho_symm.utils.algebra_utils import coo2torch_coo, slugify -log = logging.getLogger(__name__) +class IsotypicBasis(EquivariantModule): + r"""Utility non-trainable module to do the change of basis to a symmetry enabled basis (or isotypic basis).""" + def __init__(self, in_type: FieldType): + """Instantiate a non-trainable module effectively applying the change of basis to a symmetry enabled basis. -class BasisLinear(torch.nn.Module): - """Group-equivariant linear layer.""" - - def __init__(self, rep_in: BaseRep, rep_out: BaseRep, bias=True): - super().__init__() - - # TODO: Add parameter for direct/whreat product - G = SemiDirectProduct(Gin=rep_in.G, Gout=rep_out.G) - self.repW = SparseRep(G) - self.rep_in = rep_in - self.rep_out = rep_out - - self._new_coeff, self._new_bias_coeff = True, True - # Layer can be "unfreeze" and thus keep variable in case that happens. - self.unfrozed_equivariance = False - self.unfrozen_w = None - self.unfrozen_bias = None - - # Compute the nullspace - Q = self.repW.equivariant_basis() - self._sum_basis_sqrd = Q.power(2).sum() if issparse(Q) else np.sum(np.power(Q, 2)) - basis = coo2torch_coo(Q) if issparse(Q) else torch.tensor(np.asarray(Q)) - self.basis = torch.nn.Parameter(basis, requires_grad=False) - - # Create the network parameters. 
Coefficients for each base and a b - self.basis_coeff = torch.nn.Parameter(torch.randn((self.basis.shape[-1]))) - - if bias: - Qbias = rep_out.equivariant_basis() - bias_basis = coo2torch_coo(Qbias) if issparse(Qbias) else torch.tensor(np.asarray(Qbias)) - self.bias_basis = torch.nn.Parameter(bias_basis, requires_grad=False) - self.bias_basis_coeff = torch.nn.Parameter(torch.randn((self.bias_basis.shape[-1]))) - self._bias = self.bias - else: - self.bias_basis, self.bias_basis_coeff = None, None - - # TODO: Check if necessary - # self.proj_b = torchify_fn(jit(lambda b: self.P_bias @ b)) - # Initialize parameters - self.init_std = None - self.reset_parameters() - - # Check Equivariance. - EquivariantModel.test_module_equivariance(module=self, rep_in=self.rep_in, rep_out=self.rep_out) - # Add hook to backward pass - self.register_full_backward_hook(EquivariantModel.backward_hook) - - def forward(self, x): - """Normal forward pass, using weights formed by the basis and corresponding coefficients.""" - if x.device != self.weight.device: - self._new_coeff, self._new_bias_coeff = True, True - return F.linear(x, weight=self.weight, bias=self.bias) - - @property - def weight(self): - if not self.unfrozed_equivariance: - # if self._new_coeff or self._weight is None: - self._weight = torch.matmul(self.basis, self.basis_coeff).reshape((self.rep_out.G.d, self.rep_in.G.d)) - # self._new_coeff = False - return self._weight - else: - return self.unfrozen_w - - @property - def bias(self): - if not self.unfrozed_equivariance: - if self.bias_basis is not None: - # if self._new_bias_coeff or self._bias is None: - self._bias = torch.matmul(self.bias_basis, self.bias_basis_coeff).reshape((self.rep_out.G.d,)) - self._new_bias_coeff = False - return self._bias - return None - else: - return self.unfrozen_bias - - def reset_parameters(self, mode="fan_in", activation="ReLU"): - if self.unfrozed_equivariance: - raise BrokenPipeError("initialization called after unfrozed equivariance") - # Compute the constant coming from the derivative of the activation. Torch return the square root of this value - gain = torch.nn.init.calculate_gain(nonlinearity=activation.lower()) - # Get input out dimensions. - dim_in, dim_out = self.rep_in.G.d, self.rep_out.G.d - # Gain due to parameter sharing scheme from equivariance constrain - lambd = self._sum_basis_sqrd - if mode.lower() == "fan_in": - basis_coeff_variance = dim_out / lambd - elif mode.lower() == "fan_out": - basis_coeff_variance = dim_in / lambd - elif mode.lower() == "harmonic_mean": - basis_coeff_variance = 2. / ((lambd / dim_out) + (lambd / dim_in)) - elif mode.lower() == "arithmetic_mean": - basis_coeff_variance = ((dim_in + dim_out) / 2.) / lambd - elif "normal" in mode.lower(): - split = mode.split('l') - std = 0.1 if len(split) == 1 else float(split[1]) - torch.nn.init.normal_(self.basis_coeff, 0, std) - return - else: - raise NotImplementedError(f"{mode} is not a recognized mode for Kaiming initialization") - - self.init_std = gain * math.sqrt(basis_coeff_variance) - bound = math.sqrt(3.0) * self.init_std - - prev_basis_coeff = torch.clone(self.basis_coeff) - torch.nn.init.uniform_(self.basis_coeff, -bound, bound) - - self._new_coeff, self._new_bias_coeff = True, True - assert not torch.allclose(prev_basis_coeff, self.basis_coeff), "Ups, smth is wrong." 
- - def unfreeze_equivariance(self): - w, bias = self.weight, self.bias - self.unfrozed_equivariance = True - self.unfrozen_w = torch.nn.Parameter(w, requires_grad=True) - self.register_parameter('unfrozen_w', self.unfrozen_w) - if bias is not None: - self.unfrozen_bias = torch.nn.Parameter(bias, requires_grad=True) - self.register_parameter('unfrozen_bias', self.unfrozen_bias) - - def __repr__(self): - string = f"E-Linear G[{self.repW.G}]-W{self.rep_out.size() * self.rep_in.size()}-" \ - f"Wtrain:{self.basis.shape[-1]}={self.basis_coeff.shape[0] / np.prod(self.repW.size()) * 100:.1f}%" \ - f"-init_std:{self.init_std:.3f}" - return string - - def to(self, *args, **kwargs): - # When device or type changes tensors need updating. - self._new_bias_coeff, self._new_coeff = True, True - return super(BasisLinear, self).to(*args, **kwargs) - - -class BasisConv1d(torch.nn.Module): - from torch.nn.common_types import _size_1_t - - def __init__(self, rep_in: BaseRep, rep_out: BaseRep, kernel_size: _size_1_t, stride: _size_1_t = 1, - padding: Union[str, _size_1_t] = 0, dilation: _size_1_t = 1, groups: int = 1, - bias: bool = True) -> None: + Args: + in_type: The representation of the input vector field. + """ super().__init__() + self.in_type = in_type + self.group = in_type.gspace.fibergroup + # Representation iso_rep = Q2iso^-1 @ iso_basis @ Q2iso + self.iso_rep = isotypic_decomp_representation(in_type.representation) + # Output type is a symmetry enabled basis with "no change of basis" (i.e., identity matrix) + self.out_type = FieldType(in_type.gspace, + [iso_rep for iso_rep in self.iso_rep.attributes['isotypic_reps'].values()]) + # Orthogonal transformation from the input basis to the isotypic basis + self.Q2iso = Parameter(torch.from_numpy(self.iso_rep.change_of_basis_inv).float(), requires_grad=False) - # Original Implementation Parameters ___________________________________________________________ - self.kernel_size_ = int(kernel_size) - self.stride_ = _single(stride) - self.padding_ = padding if isinstance(padding, str) else _single(padding) - self.dilation_ = _single(dilation) - self.groups_ = groups - - # Custom parameters ____________________________________________________________________________ - G = SemiDirectProduct(Gin=rep_in.G, Gout=rep_out.G) - self.repW = SparseRep(G) - self.rep_in = rep_in - self.rep_out = rep_out - - # Avoid recomputing W when basis coefficients have not changed. - self._new_coeff, self._new_bias_coeff = True, True - - # Compute the nullspace - Q = self.repW.equivariant_basis() - self._sum_basis_sqrd = Q.power(2).sum() if issparse(Q) else np.sum(np.power(Q)) - basis = coo2torch_coo(Q) if issparse(Q) else torch.tensor(np.asarray(Q)) - self.basis = torch.nn.Parameter(basis, requires_grad=False) - - # Create the network parameters. Coefficients for each base, and kernel dim - self.basis_coeff = torch.nn.Parameter(torch.rand(self.basis.shape[1], self.kernel_size_), requires_grad=True) - - if bias: - Qbias = rep_out.equivariant_basis() - bias_basis = coo2torch_coo(Qbias) if issparse(Qbias) else torch.tensor(np.asarray(Qbias)) - self.bias_basis = torch.nn.Parameter(bias_basis, requires_grad=False) - self.bias_basis_coeff = torch.nn.Parameter(torch.randn((self.bias_basis.shape[-1])), requires_grad=True) - else: - self.bias_basis, self.bias_basis_coeff = None, None - - self.reset_parameters() - # Check Equivariance. 
- EquivariantModel.test_module_equivariance(module=self, rep_in=self.rep_in, rep_out=self.rep_out, - in_shape=(1, rep_in.G.d, 2)) - # Add hook to backward pass - self.register_full_backward_hook(EquivariantModel.backward_hook) - - def forward(self, x): - if x.device != self.weight.device: - self._new_coeff, self._new_bias_coeff = True, True - return F.conv1d(input=x, weight=self.weight, bias=self.bias, stride=self.stride_, padding=self.padding_, - dilation=self.dilation_, groups=self.groups_) - - @property - def weight(self): - # if self._new_coeff: - self._weight = torch.matmul(self.basis, self.basis_coeff).reshape( - (self.rep_out.G.d, self.rep_in.G.d, self.kernel_size_)) - self._new_coeff = False - return self._weight - - @property - def bias(self): - if self.bias_basis is not None: - # if self._new_bias_coeff: - self._bias = torch.matmul(self.bias_basis, self.bias_basis_coeff).reshape((self.rep_out.G.d,)) - self._new_bias_coeff = False - return self._bias - return None - - def reset_parameters(self, mode="fan_in", activation="ReLU"): - # Compute the constant coming from the derivative of the activation. Torch return the square root of this value - gain = torch.nn.init.calculate_gain(nonlinearity=activation.lower()) - # Get input out dimensions. - dim_in, dim_out = self.rep_in.G.d, self.rep_out.G.d - # Gain due to parameter sharing scheme from equivariance constrain - lambd = self._sum_basis_sqrd - if mode.lower() == "fan_in": - basis_coeff_variance = dim_out / lambd - elif mode.lower() == "fan_out": - basis_coeff_variance = dim_in / lambd - elif mode.lower() == "harmonic_mean": - basis_coeff_variance = 2. / ((lambd / dim_out) + (lambd / dim_in)) - elif mode.lower() == "arithmetic_mean": - basis_coeff_variance = ((dim_in + dim_out) / 2.) / lambd - elif "normal" in mode.lower(): - split = mode.split('l') - std = 0.1 if len(split) == 1 else float(split[1]) - torch.nn.init.normal_(self.basis_coeff, 0, std) - return - else: - raise NotImplementedError(f"{mode} is not a recognized mode for Kaiming initialization") - - self.init_std = gain * math.sqrt(basis_coeff_variance) - bound = math.sqrt(3.0) * self.init_std - - prev_basis_coeff = torch.clone(self.basis_coeff) - torch.nn.init.uniform_(self.basis_coeff, -bound, bound) - if self.bias_basis is not None: - torch.nn.init.zeros_(self.bias_basis_coeff) - - self._new_coeff, self._new_bias_coeff = True, True - assert not torch.allclose(prev_basis_coeff, self.basis_coeff), "Ups, smth is wrong." - - def __repr__(self): - string = f"E-Conv1D G[{self.repW.G}]-W{self.rep_out.size() * self.rep_in.size()}-" \ - f"Wtrain:{self.basis.shape[-1]}={self.basis_coeff.shape[0] / np.prod(self.repW.size()) * 100:.1f}%" \ - f"-init_std:{self.init_std:.3f}" - return string - - -class EquivariantModel(torch.nn.Module): - - def __init__(self, rep_in: BaseRep, rep_out: BaseRep, cache_dir: Optional[Union[str, pathlib.Path]] = None): - super(EquivariantModel, self).__init__() - self.rep_in = rep_in - self.rep_out = rep_out - self.cache_dir = cache_dir - - # Cache dir - self.cache_dir = cache_dir if cache_dir is None else pathlib.Path(cache_dir).resolve(strict=True) - if self.cache_dir is None: - log.warning("No cache directory provided. 
Nothing will be saved") - elif not self.cache_dir.exists(): - raise OSError(f"Cache dir {self.cache_dir} does not exists") - else: - log.info(f"Equivariant Module - Basis Cache dir {self.cache_dir}") - self.load_cache_file() - - @staticmethod - def backward_hook(module: torch.nn.Module, _inputs, _outputs): - if hasattr(module, '_new_coeff') and hasattr(module, '_new_bias_coeff'): - module._new_coeff, module._new_bias_coeff = True, True - - @property - def _cache_file_name(self) -> str: - EXTENSION = ".npz" - model_rep = f'{self.rep_in.G}-{self.rep_out.G}' - return slugify(model_rep) + EXTENSION - - def load_cache_file(self): - if self.cache_dir is None: - log.info("Cache Loading Failed: No cache directory provided") - return - model_cache_file = self.cache_dir.joinpath(self._cache_file_name) - - if not model_cache_file.exists(): - log.warning(f"Model cache {model_cache_file.stem} not found") - return - - try: - lazy_cache = np.load(model_cache_file, allow_pickle=True, mmap_mode='c') - - run_cache = self.rep_in.solcache - if isinstance(run_cache, EMLPCache): - cache = run_cache.cache - else: - cache = run_cache - - # Remove from memory cache all file-saved caches. Taking advantage of lazy loading. - for k in list(cache.keys()): - if str(k) in lazy_cache: - cache.pop(k) - - Rep.solcache = EMLPCache(cache, lazy_cache) - log.info(f"Cache loaded for: {list(Rep.solcache.keys())}") - except Exception as e: - log.warning(f"Error while loading cache from {model_cache_file}: \n {e}") - - def save_cache_file(self): - if self.cache_dir is None: - log.info("Cache Saving Failed: No cache directory provided") - return - - model_cache_file = self.cache_dir.joinpath(self._cache_file_name) - - run_cache = self.rep_in.solcache - if isinstance(run_cache, EMLPCache): - lazy_cache, cache = run_cache.lazy_cache, run_cache.cache - else: - lazy_cache, cache = {}, run_cache + def forward(self, x: GeometricTensor) -> GeometricTensor: + """Change of basis of the input field to a symmetry enabled basis (or isotypic basis).""" + assert x.type == self.in_type, f"Input type {x.type} does not match module's input type {self.in_type}" + x_iso = torch.einsum('ij,...j->...i', self.Q2iso, x.tensor) + return self.out_type(x_iso) - if len(run_cache) == 0: - log.debug("Ignoring cache save as there is no new equivariant basis") - try: - combined_cache = {str(k): np.asarray(v) for k, v in itertools.chain(lazy_cache.items(), cache.items())} - np.savez_compressed(model_cache_file, **combined_cache) + def evaluate_output_shape(self, input_shape: Tuple[int, ...]) -> Tuple[int, ...]: + """Shape of the output field.""" + return input_shape - # Since we moved all cache to disk with lazy loading. 
Remove from memory - self.rep_in.solcache = EMLPCache(cache={}, lazy_cache=np.load(str(model_cache_file), allow_pickle=True)) - log.info(f"Saved cache from {list(self.rep_in.solcache.keys())} to {model_cache_file}") - except BadZipFile as e: - self.rep_in.solcache.lazy_cache = {} - log.warning(f"Error while saving cache to {model_cache_file}: \n {e}") - except Exception as e: - log.warning(f"Error while saving cache to {model_cache_file}: \n {e}") - @staticmethod - def test_module_equivariance(module: torch.nn.Module, rep_in, rep_out, in_shape=None): - module.eval() - shape = (rep_in.G.d,) if in_shape is None else in_shape - x = torch.randn(shape) - for g_in, g_out in zip(rep_in.G.discrete_generators, rep_out.G.discrete_generators): - g_in, g_out = (g_in.todense(), g_out.todense()) if issparse(g_in) else (g_in, g_out) - g_in = torch.tensor(np.asarray(g_in), dtype=torch.float32).unsqueeze(0) - g_out = torch.tensor(np.asarray(g_out), dtype=torch.float32).unsqueeze(0) +if __name__ == '__main__': + G = escnn.group.DihedralGroup(5) + gspace = escnn.gspaces.no_base_space(G) - y = module.forward(x) + in_type = escnn.nn.FieldType(gspace, [G.regular_representation] * 2) - if x.ndim == 3: - g_x = (g_in @ x.unsqueeze(1)).squeeze(1) - else: - g_x = g_in @ x + rep_iso_basis = isotypic_decomp_representation(in_type.representation) - g_y_pred = module.forward(g_x) - g_y_true = g_out @ y - if not torch.allclose(g_y_true, g_y_pred, atol=1e-4, rtol=1e-4): - torch.max(g_y_true - g_y_pred).item() - g_in.squeeze(0).numpy() - g_out.squeeze(0).numpy() - error = (g_y_true - g_y_pred).detach().numpy() - raise RuntimeError(f"{module}\nis not equivariant to in/out group generators\n" - f"max(f(g·x) - g·y) = {np.max(error)}") + iso_module = IsotypicBasis(in_type) - if torch.allclose(g_y_pred, y, atol=1e-4, rtol=1e-4): - log.warning(f"\nModule {module} is INVARIANT! 
not EQUIVARIANT\n") - module.train() + x_np = np.random.randn(1, in_type.size) + x = in_type(torch.from_numpy(x_np).float()) + x_iso = iso_module(x) - @property - def model_class(self): - return self.__class__.__name__ + iso_rep = iso_module.iso_rep - # def __repr__(self): - # return f'{self.model_class}: {self.rep_in.G}-{self.rep_out.G}' + x_np_iso = (iso_rep.change_of_basis_inv @ x_np.T).T + assert np.allclose(x_np_iso, x_iso.tensor.numpy()), f"{x_np_iso - x_iso.tensor.numpy()}!=0" diff --git a/morpho_symm/nn/LightningModel.py b/morpho_symm/nn/LightningModel.py index 6aa09cb..5d5f441 100644 --- a/morpho_symm/nn/LightningModel.py +++ b/morpho_symm/nn/LightningModel.py @@ -7,9 +7,8 @@ import pytorch_lightning as pl import torch -from morpho_symm.nn.EMLP import EMLP, MLP, LinearBlock - -from .EquivariantModules import BasisLinear +from morpho_symm.nn.EMLP import EMLP +from morpho_symm.nn.MLP import MLP log = logging.getLogger(__name__) @@ -138,36 +137,6 @@ def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) return optimizer - def log_weights(self): - if not self.logger: return - tb_logger = self.logger.experiment - layer_index = 0 # Count layers by linear operators not position in network sequence - for layer in self.model.net: - layer_name = f"Layer{layer_index:02d}" - if isinstance(layer, EquivariantBlock) or isinstance(layer, BasisLinear): - lin = layer.linear if isinstance(layer, EquivariantBlock) else layer - W = lin.weight.view(-1).detach() - basis_coeff = lin.basis_coeff.view(-1).detach() - tb_logger.add_histogram(tag=f"{layer_name}/c", values=basis_coeff, global_step=self.current_epoch) - tb_logger.add_histogram(tag=f"{layer_name}/W", values=W, global_step=self.current_epoch) - layer_index += 1 - elif isinstance(layer, LinearBlock) or isinstance(layer, torch.nn.Linear): - lin = layer.linear if isinstance(layer, LinearBlock) else layer - W = lin.weight.view(-1).detach() - tb_logger.add_histogram(tag=f"{layer_name}/W", values=W, global_step=self.current_epoch) - layer_index += 1 - - def log_preactivations(self, ): - if not self.logger: return - tb_logger = self.logger.experiment - layer_index = 0 # Count layers by linear operators not position in network sequence - for layer in self.model.net: - layer_name = f"Layer{layer_index:02d}" - if isinstance(layer, EquivariantBlock) or isinstance(layer, LinearBlock): - tb_logger.add_histogram(tag=f"{layer_name}/pre-act", values=layer._preact, - global_step=self.current_epoch) - layer_index += 1 - def get_metrics(self): # don't show the version number on console logs. items = super().get_metrics() diff --git a/morpho_symm/nn/MLP.py b/morpho_symm/nn/MLP.py new file mode 100644 index 0000000..028b234 --- /dev/null +++ b/morpho_symm/nn/MLP.py @@ -0,0 +1,114 @@ +import logging +from typing import List, Union + +import torch + +log = logging.getLogger(__name__) + + +class MLP(torch.nn.Module): + """Standard baseline MLP. Representations and group are used for shapes only.""" + + def __init__(self, + in_dim: int, + out_dim: int, + num_hidden_units: int = 64, + num_layers: int = 3, + bias: bool = True, + batch_norm: bool = True, + head_with_activation: bool = False, + activation: Union[torch.nn.Module, List[torch.nn.Module]] = torch.nn.ReLU, + init_mode="fan_in"): + """Constructor of a Multi-Layer Perceptron (MLP) model. + + This utility class allows to easily instanciate a G-equivariant MLP architecture. + + Args: + ---- + in_dim: Dimension of the input space. + out_dim: Dimension of the output space. 
+ num_hidden_units: Number of hidden units in the intermediate layers. + num_layers: Number of layers in the MLP including input and output/head layers. That is, the total number of linear layers. + activation (torch.nn.Module, list(torch.nn.Module)): If a single activation module is + provided it will be used for all layers except the output layer. If a list of activation modules is provided + then `num_layers - 1` activation modules should be provided (one per hidden layer). + bias: Whether to include a bias term in the linear layers. + init_mode: Mode used to initialize the weights of the MLP ('fan_in', 'fan_out', or 'normal'). + """ + super().__init__() + logging.info("Instantiating MLP (PyTorch)") + self.in_dim, self.out_dim = in_dim, out_dim + self.init_mode = init_mode if init_mode is not None else "fan_in" + self.hidden_channels = num_hidden_units + self.activation = activation if isinstance(activation, list) else [activation] * (num_layers - 1) + + self.num_layers = num_layers + if self.num_layers == 1 and not head_with_activation: + log.warning(f"{self} model with 1 layer and no activation. This is equivalent to a linear map") + + dim_in = self.in_dim + dim_out = num_hidden_units + + self.net = torch.nn.Sequential() + for n in range(self.num_layers - 1): + dim_out = num_hidden_units + + block = torch.nn.Sequential() + block.add_module(f"linear_{n}", torch.nn.Linear(dim_in, dim_out, bias=bias)) + if batch_norm: + block.add_module(f"batchnorm_{n}", torch.nn.BatchNorm1d(dim_out)) + block.add_module(f"act_{n}", activation()) + + self.net.add_module(f"block_{n}", block) + dim_in = dim_out + + # Add last layer + head_block = torch.nn.Sequential() + head_block.add_module(f"linear_{num_layers - 1}", torch.nn.Linear(in_features=dim_out, out_features=self.out_dim, bias=bias)) + if head_with_activation: + if batch_norm: + head_block.add_module(f"batchnorm_{num_layers - 1}", torch.nn.BatchNorm1d(self.out_dim)) + head_block.add_module(f"act_{num_layers - 1}", activation()) + + self.net.add_module("head", head_block) + + self.reset_parameters(init_mode=self.init_mode) + + def forward(self, input): + """Forward pass of the MLP model.""" + output = self.net(input) + return output + + def get_hparams(self): + return {'num_layers': self.num_layers, + 'hidden_ch': self.hidden_channels, + 'init_mode': self.init_mode} + + def reset_parameters(self, init_mode=None): + assert init_mode is not None or self.init_mode is not None + self.init_mode = self.init_mode if init_mode is None else init_mode + for module in self.net: + if isinstance(module, torch.nn.Sequential): + tensor = module[0].weight + activation = module[-1].__class__.__name__ + activation = "linear" if activation == "Identity" else activation + elif isinstance(module, torch.nn.Linear): + tensor = module.weight + activation = "Linear" + else: + raise NotImplementedError(module.__class__.__name__) + + if "fan_in" == self.init_mode or "fan_out" == self.init_mode: + try: + torch.nn.init.kaiming_uniform_(tensor, mode=self.init_mode, nonlinearity=activation.lower()) + except ValueError as e: + log.warning(f"Could not initialize {module.__class__.__name__} with {self.init_mode} mode. 
" + f"Using default Pytorch initialization") + elif 'normal' in self.init_mode.lower(): + split = self.init_mode.split('l') + std = 0.1 if len(split) == 1 else float(split[1]) + torch.nn.init.normal_(tensor, 0, std) + else: + raise NotImplementedError(self.init_mode) + + log.info(f"MLP initialized with mode: {self.init_mode}") \ No newline at end of file diff --git a/morpho_symm/utils/group_utils.py b/morpho_symm/utils/group_utils.py deleted file mode 100644 index 1fda0b1..0000000 --- a/morpho_symm/utils/group_utils.py +++ /dev/null @@ -1,101 +0,0 @@ -import itertools -from typing import Dict - -import numpy as np -from escnn.group import CyclicGroup, DihedralGroup, DirectProductGroup, Group, GroupElement -from escnn.group.representation import Representation - -from morpho_symm.groups.isotypic_decomposition import escnn_representation_form_mapping - - -def generate_cyclic_rep(G: CyclicGroup, rep): - """Generate cylic froup form a representation of its generator.""" - h = G.generators[0] - # Check the given matrix representations comply with group axioms - assert not np.allclose(rep[h], rep[G.identity]), "Invalid generator: h=e" - assert np.allclose(np.linalg.matrix_power(rep[h], G.order()), rep[G.identity]), \ - f"Invalid rotation generator h_ref^{G.order()} != I" - - curr_g = h - while len(rep) < G.order(): # Use generator to obtain all elements and element reps in group - g = curr_g @ h - rep[g] = rep[curr_g] @ rep[h] - curr_g = g - - return rep - - -def generate_dihedral_rep(G: DihedralGroup, rep): - """Generate dihedral group form a representation of its generators.""" - h_rot, h_ref = G.generators - # Check the given matrix representations comply with group axioms - # assert not np.allclose(rep[h_ref], rep[G.identity]), "Invalid reflection generator: h_ref=e" - # assert not np.allclose(rep[h_rot], rep[G.identity]), "Invalid rotation generator: h_rot=e" - assert np.allclose(rep[h_ref] @ rep[h_ref], rep[G.identity]), "Invalid reflection generator `h_ref @ h_ref != I`" - assert np.allclose(np.linalg.matrix_power(rep[h_rot], G.order()//2), rep[G.identity]),\ - f"Invalid rotation generator h_ref^{G.order} != I" - - curr_g, curr_ref_g = h_rot, h_ref @ h_rot - rep[curr_ref_g] = rep[h_ref] @ rep[h_rot] - while len(rep) < G.order(): # Use generator to obtain all elements and element reps in group - g = curr_g @ h_rot - gr = curr_ref_g @ h_rot - rep[g] = rep[curr_g] @ rep[h_rot] - rep[gr] = rep[curr_ref_g] @ rep[h_rot] - curr_g, curr_ref_g = g, gr - - return rep - - -def generate_direct_product_rep(G: DirectProductGroup, rep1, rep2): - """Generate direct product group form the two representations of each group generators.""" - rep = {} - for h1, h2 in itertools.product(rep1.keys(), rep2.keys()): - g = G.pair_elements(h1, h2) - rep[g] = rep1[h1] @ rep2[h2] - return rep - - -def group_rep_from_gens(G: Group, rep_H: Dict[GroupElement, np.ndarray]) -> Representation: - """Generate a representation of all group actions from the representations of the group generators H={h1, h2, ...}. - - Such that any group action `g` can be obtained by multiplying the generators `h1`, `h2`, ... in `H`. - Being all `h` in `H` - Args: - G (Group): Group to generate representation for. - rep_H (Dict[GroupElement, np.ndarray]): Representation of the group generators H={h1, h2, ...}. - - Returns: - Representation: ESCNN Representation of group actions generated by the input rep.. 
- """ - if G.identity not in rep_H: - rep_H[G.identity] = np.eye(list(rep_H.values())[0].shape[0]) - - if isinstance(G, CyclicGroup): - rep_H = generate_cyclic_rep(G, rep_H) - elif isinstance(G, DihedralGroup): - rep_H = generate_dihedral_rep(G, rep_H) - elif isinstance(G, DirectProductGroup): - # Extract the generators of first and second group, generate the groups independently and then combine them - H1, H2 = zip(*[G.split_element(h) for h in G.generators]) - rep_G1 = {G.G1.identity: rep_H[G.pair_elements(G.G1.identity, G.G2.identity)]} - rep_G2 = {G.G2.identity: rep_H[G.pair_elements(G.G1.identity, G.G2.identity)]} - rep_G1.update({h1: rep_H[G.inclusion1(h1)] for h1 in H1}) - rep_G2.update({h2: rep_H[G.inclusion2(h2)] for h2 in H2}) - - # generate each subgroup representation - group_rep_from_gens(G.G1, rep_G1) - group_rep_from_gens(G.G2, rep_G2) - - # Do direct product of the generated subgroups reps. - rep_H = generate_direct_product_rep(G, rep_G1, rep_G2) - else: - raise NotImplementedError(f"Group {G} not implemented yet.") - - # Convert Dict[GroupElement, np.ndarray] to escnn `Representation` - rep_escnn = escnn_representation_form_mapping(G, rep_H) - - return rep_escnn - - - diff --git a/morpho_symm/utils/rep_theory_utils.py b/morpho_symm/utils/rep_theory_utils.py new file mode 100644 index 0000000..4143100 --- /dev/null +++ b/morpho_symm/utils/rep_theory_utils.py @@ -0,0 +1,181 @@ +import functools +import itertools +from collections import OrderedDict +from typing import Dict + +import numpy as np +from escnn.group import CyclicGroup, DihedralGroup, DirectProductGroup, Group, GroupElement +from escnn.group.representation import Representation, directsum + +from morpho_symm.groups.isotypic_decomposition import escnn_representation_form_mapping +from morpho_symm.utils.algebra_utils import permutation_matrix + + +def generate_cyclic_rep(G: CyclicGroup, rep): + """Generate cylic froup form a representation of its generator.""" + h = G.generators[0] + # Check the given matrix representations comply with group axioms + assert not np.allclose(rep[h], rep[G.identity]), "Invalid generator: h=e" + assert np.allclose(np.linalg.matrix_power(rep[h], G.order()), rep[G.identity]), \ + f"Invalid rotation generator h_ref^{G.order()} != I" + + curr_g = h + while len(rep) < G.order(): # Use generator to obtain all elements and element reps in group + g = curr_g @ h + rep[g] = rep[curr_g] @ rep[h] + curr_g = g + + return rep + + +def generate_dihedral_rep(G: DihedralGroup, rep): + """Generate dihedral group form a representation of its generators.""" + h_rot, h_ref = G.generators + # Check the given matrix representations comply with group axioms + # assert not np.allclose(rep[h_ref], rep[G.identity]), "Invalid reflection generator: h_ref=e" + # assert not np.allclose(rep[h_rot], rep[G.identity]), "Invalid rotation generator: h_rot=e" + assert np.allclose(rep[h_ref] @ rep[h_ref], rep[G.identity]), "Invalid reflection generator `h_ref @ h_ref != I`" + assert np.allclose(np.linalg.matrix_power(rep[h_rot], G.order() // 2), rep[G.identity]), \ + f"Invalid rotation generator h_ref^{G.order} != I" + + curr_g, curr_ref_g = h_rot, h_ref @ h_rot + rep[curr_ref_g] = rep[h_ref] @ rep[h_rot] + while len(rep) < G.order(): # Use generator to obtain all elements and element reps in group + g = curr_g @ h_rot + gr = curr_ref_g @ h_rot + rep[g] = rep[curr_g] @ rep[h_rot] + rep[gr] = rep[curr_ref_g] @ rep[h_rot] + curr_g, curr_ref_g = g, gr + + return rep + + +def generate_direct_product_rep(G: DirectProductGroup, 
rep1, rep2): + """Generate a direct product group representation from the representations of each subgroup's generators.""" + rep = {} + for h1, h2 in itertools.product(rep1.keys(), rep2.keys()): + g = G.pair_elements(h1, h2) + rep[g] = rep1[h1] @ rep2[h2] + return rep + + +def group_rep_from_gens(G: Group, rep_H: Dict[GroupElement, np.ndarray]) -> Representation: + """Generate a representation of all group actions from the representations of the group generators H={h1, h2, ...}. + + Such that any group element `g` can be obtained by multiplying the generators `h1`, `h2`, ... in `H`. + Args: + G (Group): Group to generate representation for. + rep_H (Dict[GroupElement, np.ndarray]): Representation of the group generators H={h1, h2, ...}. + + Returns: + Representation: ESCNN Representation of the group actions generated by the input generator representations. + """ + if G.identity not in rep_H: + rep_H[G.identity] = np.eye(list(rep_H.values())[0].shape[0]) + + if isinstance(G, CyclicGroup): + rep_H = generate_cyclic_rep(G, rep_H) + elif isinstance(G, DihedralGroup): + rep_H = generate_dihedral_rep(G, rep_H) + elif isinstance(G, DirectProductGroup): + # Extract the generators of first and second group, generate the groups independently and then combine them + H1, H2 = zip(*[G.split_element(h) for h in G.generators]) + rep_G1 = {G.G1.identity: rep_H[G.pair_elements(G.G1.identity, G.G2.identity)]} + rep_G2 = {G.G2.identity: rep_H[G.pair_elements(G.G1.identity, G.G2.identity)]} + rep_G1.update({h1: rep_H[G.inclusion1(h1)] for h1 in H1}) + rep_G2.update({h2: rep_H[G.inclusion2(h2)] for h2 in H2}) + + # generate each subgroup representation + group_rep_from_gens(G.G1, rep_G1) + group_rep_from_gens(G.G2, rep_G2) + + # Do direct product of the generated subgroups reps. + rep_H = generate_direct_product_rep(G, rep_G1, rep_G2) + else: + raise NotImplementedError(f"Group {G} not implemented yet.") + + # Convert Dict[GroupElement, np.ndarray] to escnn `Representation` + rep_escnn = escnn_representation_form_mapping(G, rep_H) + return rep_escnn + + +def irreps_stats(irreps_ids): + str_ids = [str(irrep_id) for irrep_id in irreps_ids] + unique_str_ids, counts, indices = np.unique(str_ids, return_counts=True, return_index=True) + unique_ids = [eval(s) for s in unique_str_ids] + return unique_ids, counts, indices + + +def isotypic_decomp_representation(rep: Representation) -> Representation: + """Returns a representation in a "symmetry enabled basis" (a.k.a Isotypic Basis). + + Takes a representation with an arbitrary basis (i.e., arbitrary change of basis and an arbitrary order of + irreducible representations in the escnn Representation) and returns a new representation in which the basis + is changed to a "symmetry enabled basis" (a.k.a Isotypic Basis). That is, a representation in which the + vector space is decomposed into a direct sum of Isotypic Subspaces. Each Isotypic Subspace is a subspace of the + original vector space with a subspace representation composed of multiplicities of a single irreducible + representation. In other words, each Isotypic Subspace is a subspace with a subgroup of symmetries of the original + vector space's symmetry group. + + Args: + rep (Representation): Input representation in any arbitrary basis. + + Returns: A `Representation` with a change of basis exposing an Isotypic Basis (a.k.a symmetry enabled basis). + The instance of the representation contains an additional attribute `isotypic_reps` which is an + `OrderedDict` of representations per each isotypic subspace. The keys are the active irreps' ids associated 
The keys are the active irreps' ids associated + with each Isotypic subspace. + """ + symm_group = rep.group + potential_irreps = rep.group.irreps() + isotypic_subspaces_indices = {irrep.id: [] for irrep in potential_irreps} + + for pot_irrep in potential_irreps: + cur_dim = 0 + for rep_irrep_id in rep.irreps: + rep_irrep = symm_group.irrep(*rep_irrep_id) + if rep_irrep == pot_irrep: + isotypic_subspaces_indices[rep_irrep_id].append(list(range(cur_dim, cur_dim + rep_irrep.size))) + cur_dim += rep_irrep.size + + # Remove inactive Isotypic Spaces + for irrep in potential_irreps: + if len(isotypic_subspaces_indices[irrep.id]) == 0: + del isotypic_subspaces_indices[irrep.id] + + # Each Isotypic Space will be indexed by the irrep it is associated with. + active_isotypic_reps = {} + for irrep_id, indices in isotypic_subspaces_indices.items(): + irrep = symm_group.irrep(*irrep_id) + multiplicities = len(indices) + active_isotypic_reps[irrep_id] = Representation(group=rep.group, + irreps=[irrep_id] * multiplicities, + name=f'IsoSubspace {irrep_id}', + change_of_basis=np.identity(irrep.size * multiplicities), + supported_nonlinearities=irrep.supported_nonlinearities) + + # Impose canonical order on the Isotypic Subspaces. + # If the trivial representation is active it will be the first Isotypic Subspace. + # Then sort by dimension of the space from smallest to largest. + ordered_isotypic_reps = OrderedDict(sorted(active_isotypic_reps.items(), key=lambda item: item[1].size)) + if symm_group.trivial_representation.id in ordered_isotypic_reps.keys(): + ordered_isotypic_reps.move_to_end(symm_group.trivial_representation.id, last=False) + + # Required permutation to change the order of the irreps. So we obtain irreps of the same type consecutively. + oneline_permutation = [] + for irrep_id, iso_rep in ordered_isotypic_reps.items(): + idx = isotypic_subspaces_indices[irrep_id] + oneline_permutation.extend(idx) + oneline_permutation = np.concatenate(oneline_permutation) + P_in2iso = permutation_matrix(oneline_permutation) + + Q_iso = rep.change_of_basis_inv @ P_in2iso.T + rep_iso_basis = directsum(list(ordered_isotypic_reps.values()), + name=rep.name + '-Iso', + change_of_basis=Q_iso) + + iso_supported_nonlinearities = [iso_rep.supported_nonlinearities for iso_rep in ordered_isotypic_reps.values()] + rep_iso_basis.supported_nonlinearities = functools.reduce(set.intersection, iso_supported_nonlinearities) + rep_iso_basis.attributes['isotypic_reps'] = ordered_isotypic_reps + + return rep_iso_basis diff --git a/morpho_symm/utils/robot_utils.py b/morpho_symm/utils/robot_utils.py index f120645..af2f002 100644 --- a/morpho_symm/utils/robot_utils.py +++ b/morpho_symm/utils/robot_utils.py @@ -17,7 +17,7 @@ import morpho_symm from morpho_symm.robots.PinBulletWrapper import PinBulletWrapper from morpho_symm.utils.algebra_utils import gen_permutation_matrix -from morpho_symm.utils.group_utils import group_rep_from_gens +from morpho_symm.utils.rep_theory_utils import group_rep_from_gens from morpho_symm.utils.pybullet_visual_utils import ( change_robot_appearance, configure_bullet_simulation, diff --git a/pyproject.toml b/pyproject.toml index a545582..bebdb27 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ allow-direct-references = true # We use custom branch of escnn with some devel [project] name = "morpho_symm" -version = "0.1.1" +version = "0.1.2" keywords = ["morphological symmetry", "locomotion", "dynamical systems", "robot symmetries", "symmetry"] description = "Tools for the identification, 
study, and exploitation of morphological symmetries in locomoting dynamical systems" readme = "README.md" @@ -47,7 +47,6 @@ dependencies = [ learning = [ "pytorch-lightning>=1.7.7", "torch>=1.12.1", - "emlp", ] [project.urls]
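A minimal usage sketch of the new G-invariant EMLP path introduced above, mirroring the `__main__` test added in morpho_symm/nn/EMLP.py (the import path, constructor arguments, and tolerances are assumed from that test, not part of the diff itself):

import escnn
import torch
from morpho_symm.nn.EMLP import EMLP

G = escnn.group.DihedralGroup(6)
gspace = escnn.gspaces.no_base_space(G)
# An output type made only of trivial representations triggers the invariant head:
# IsotypicBasis change of basis -> irrep_norm_pooling -> unconstrained torch.nn.Linear.
in_type = escnn.nn.FieldType(gspace, [G.regular_representation] * 5)
out_type = escnn.nn.FieldType(gspace, [G.trivial_representation] * 6)
model = EMLP(in_type, out_type, num_hidden_units=128, num_layers=3,
             activation="ReLU", head_with_activation=False)
model.eval()  # disable batch-norm statistics so the check is deterministic
x = in_type(torch.randn(1, in_type.size))
y = model(x)
for g in G.elements:
    g_x = in_type(in_type.transform_fibers(x.tensor, g))  # g · x
    assert torch.allclose(y.tensor, model(g_x).tensor, atol=1e-4, rtol=1e-4)  # f(g·x) == f(x)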