Skip to content

Commit

Permalink
G-Invariant nn construction
Browse files Browse the repository at this point in the history
This commit enables the "easy" parametric construction of G-invariant MLPs. Achieving G-invariance without paying a high cost on the expresivity of the resultant Neural Network is a bit tricky and not properly handeled in ESCNN.

In this commit we introduce:
  - An Equivariant Module to perform the Isotypic Decomposition of any vector space
  - On the Isotypic Basis we efficiently compute G-invariant features from a symmetric (latent) vector space, by taking the norm of the projection of the latent features to each G-stable irreducible subpsace associated with an irreducible representation. These are by construction G-invariant features, since we are dealing with "Point-Groups" which do leave the origin of these G-stable subspaces invariant.
  • Loading branch information
Danfoa committed Oct 25, 2023
1 parent 3b4ad44 commit 270bc9d
Show file tree
Hide file tree
Showing 8 changed files with 468 additions and 619 deletions.
234 changes: 128 additions & 106 deletions morpho_symm/nn/EMLP.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import logging
from typing import List, Union
from typing import Union

import escnn
import numpy as np
import torch
from escnn.nn import EquivariantModule, FieldType

from morpho_symm.nn.EquivariantModules import IsotypicBasis

log = logging.getLogger(__name__)


Expand Down Expand Up @@ -33,7 +35,13 @@ def __init__(self,
Args:
----
in_type (escnn.nn.FieldType): Input field type containing the representation of the input space.
out_type (escnn.nn.FieldType): Output field type containing the representation of the output space.
out_type (escnn.nn.FieldType): Output field type containing the representation of the output space. If the
output type representation is composed only of multiples of the trivial representation then the this model
will default to a G-invariant function. This G-invariant function is composed of a G-equivariant feature
extractor, a G-invariant pooling layer, extracting invariant features from the equivariant features, and a
last linear (unconstrained) layer to map the invariant features to the output space.
TODO: Create a G-invariant EMLP class where we can control the network processing the G-invariant
features instead of defaulting to a single linear layer.
num_hidden_units: Number of hidden units in the intermediate layers. The effective number of hidden units
will be ceil(num_hidden_units/|G|). Since we assume intermediate embeddings are regular fields.
activation (escnn.nn.EquivariantModule, str): Name of pointwise activation function to use.
Expand All @@ -51,8 +59,21 @@ def __init__(self,
self.in_type, self.out_type = in_type, out_type
self.gspace = self.in_type.gspace
self.group = self.gspace.fibergroup

self.num_layers = num_layers

# Check if the network is a G-invariant function (i.e., out rep is composed only of the trivial representation)
out_irreps = set(out_type.representation.irreps)
if len(out_irreps) == 1 and self.group.trivial_representation.id == list(out_irreps)[0]:
self.invariant_fn = True
else:
self.invariant_fn = False
input_irreps = set(in_type.representation.irreps)
inner_irreps = set(out_type.irreps)
diff = input_irreps.symmetric_difference(inner_irreps)
if len(diff) > 0:
log.warning(f"Irreps {list(diff)} of group {self.gspace.fibergroup} are not in the input/output types."
f"This represents an information bottleneck. Consider extracting invariant features.")

if self.num_layers == 1 and not head_with_activation:
log.warning(f"{self} model with 1 layer and no activation. This is equivalent to a linear map")

Expand All @@ -66,15 +87,7 @@ def __init__(self,
else:
raise ValueError(f"Activation type {type(activation)} not supported.")

input_irreps = set(in_type.representation.irreps)
inner_irreps = set(out_type.irreps)
diff = input_irreps.symmetric_difference(inner_irreps)
if len(diff) > 0:
log.warning(f"Irreps {list(diff)} of group {self.gspace.fibergroup} are not in the input/output types."
f"This represents an information bottleneck. Consider extracting invariant features.")

layer_in_type = in_type

self.net = escnn.nn.SequentialModule()
for n in range(self.num_layers - 1):
layer_out_type = hidden_type
Expand All @@ -89,22 +102,53 @@ def __init__(self,
layer_in_type = layer_out_type

# Add final layer
head_block = escnn.nn.SequentialModule()
head_block.add_module(f"linear_{num_layers - 1}", escnn.nn.Linear(layer_in_type, out_type, bias=bias))
if head_with_activation:
if batch_norm:
head_block.add_module(f"batchnorm_{num_layers - 1}", escnn.nn.IIDBatchNorm1d(out_type)),
head_block.add_module(f"act_{num_layers - 1}", activation)
# head_layer.check_equivariance()
self.net.add_module("head", head_block)
self.net_head = None
if self.invariant_fn:
self.net_head = torch.nn.Sequential()
# TODO: Make the G-invariant pooling with Isotypic Basis a stand alone module.
# Module describing the change of basis to an Isotypic Basis required for efficient G-invariant pooling
self.change2isotypic_basis = IsotypicBasis(hidden_type)
# Number of G-invariant features from net output equals the number of G-stable subspaces.
num_inv_features = len(hidden_type.irreps)
self.net_head.add_module(f"linear_{num_layers - 1}",
torch.nn.Linear(num_inv_features, out_type.size, bias=bias))
if head_with_activation:
if batch_norm:
self.net_head.add_module(f"batchnorm_{num_layers - 1}", torch.nn.BatchNorm1d(out_type.size)),
self.net_head.add_module(f"act_{num_layers - 1}", activation)
else: # Equivariant Network
self.net_head = escnn.nn.SequentialModule()
self.net_head.add_module(f"linear_{num_layers - 1}", escnn.nn.Linear(layer_in_type, out_type, bias=bias))
if head_with_activation:
if batch_norm:
self.net_head.add_module(f"batchnorm_{num_layers - 1}", escnn.nn.IIDBatchNorm1d(out_type)),
self.net_head.add_module(f"act_{num_layers - 1}", activation)
# Test the entire model is equivariant.
# self.net.check_equivariance()

def forward(self, x):
"""Forward pass of the EMLP model."""
equivariant_features = self.net(x)
if self.invariant_fn:
iso_equivariant_features = self.change2isotypic_basis(equivariant_features)
invariant_features = self.irrep_norm_pooling(iso_equivariant_features.tensor, iso_equivariant_features.type)
output = self.net_head(invariant_features)
output = self.out_type(output) # Wrap over invariant field type
else:
output = self.net_head(equivariant_features)
return output

def reset_parameters(self, init_mode=None):
"""Initialize weights and biases of E-MLP model."""
raise NotImplementedError()

@staticmethod
def get_activation(activation, in_type: FieldType, desired_hidden_units: int):
gspace = in_type.gspace
group = gspace.fibergroup
grid_length = group.order() if not group.continuous else 20
if isinstance(group, escnn.group.DihedralGroup):
grid_length = grid_length // 2

unique_irreps = set(in_type.irreps)
unique_irreps_dim = sum([group.irrep(*id).size for id in set(in_type.irreps)])
Expand All @@ -122,97 +166,75 @@ def get_activation(activation, in_type: FieldType, desired_hidden_units: int):
type='regular' if not group.continuous else 'rand',
N=grid_length)

def forward(self, x):
"""Forward pass of the EMLP model."""
return self.net(x)

def reset_parameters(self, init_mode=None):
"""Initialize weights and biases of E-MLP model."""
raise NotImplementedError()
@staticmethod
def irrep_norm_pooling(x: torch.Tensor, field_type: FieldType) -> torch.Tensor:
n_inv_features = len(field_type.irreps)
# TODO: Ensure isotypic basis i.e irreps of the same type are consecutive to each other.
inv_features = []
for field_start, field_end, rep in zip(field_type.fields_start,
field_type.fields_end,
field_type.representations):
# Each field here represents a representation of an Isotypic Subspace. This rep is only composed of a single
# irrep type.
x_field = x[..., field_start:field_end]
num_G_stable_spaces = len(rep.irreps) # Number of G-invariant features = multiplicity of irrep
# Again this assumes we are already in an Isotypic basis
assert len(np.unique(rep.irreps, axis=0)) == 1, "This only works for now on the Isotypic Basis"
# This basis is useful because we can apply the norm in a vectorized way
# Reshape features to [batch, num_G_stable_spaces, num_features_per_G_stable_space]
x_field_p = torch.reshape(x_field, (x_field.shape[0], num_G_stable_spaces, -1))
# Compute G-invariant measures as the norm of the features in each G-stable space
inv_field_features = torch.norm(x_field_p, dim=-1)
# Append to the list of inv features
inv_features.append(inv_field_features)
# Concatenate all the invariant features
inv_features = torch.cat(inv_features, dim=-1)
assert inv_features.shape[-1] == n_inv_features, f"Expected {n_inv_features} got {inv_features.shape[-1]}"
return inv_features

def evaluate_output_shape(self, input_shape: tuple[int, ...]) -> tuple[int, ...]:
"""Returns the output shape of the model given an input shape."""
batch_size = input_shape[0]
return batch_size, self.out_type.size


class MLP(torch.nn.Module):
"""Standard baseline MLP. Representations and group are used for shapes only."""

def __init__(self, d_in, d_out, num_hidden_units=128, num_layers=3,
activation: Union[torch.nn.Module, List[torch.nn.Module]] = torch.nn.ReLU,
with_bias=True, init_mode="fan_in"):
"""Constructor of a Multi-Layer Perceptron (MLP) model.
This utility class allows to easily instanciate a G-equivariant MLP architecture.
Args:
d_in: Dimension of the input space.
d_out: Dimension of the output space.
num_hidden_units: Number of hidden units in the intermediate layers.
num_layers: Number of layers in the MLP including input and output/head layers. That is, the number of
activation (escnn.nn.EquivariantModule, list(escnn.nn.EquivariantModule)): If a single activation module is
provided it will be used for all layers except the output layer. If a list of activation modules is provided
then `num_layers` activation equivariant modules should be provided.
with_bias: Whether to include a bias term in the linear layers.
init_mode: Not used until now. Will be used to initialize the weights of the MLP
"""
super().__init__()
self.d_in = d_in
self.d_out = d_out
self.init_mode = init_mode
self.hidden_channels = num_hidden_units
self.activation = activation

logging.info("Initializing MLP")

dim_in = self.d_in
dim_out = num_hidden_units
self.net = torch.nn.Sequential()
for n in range(num_layers - 1):
dim_out = num_hidden_units
block = torch.nn.Sequential()
block.add_module(f"linear_{n}", torch.nn.Linear(dim_in, dim_out, bias=with_bias))
block.add_module(f"batchnorm_{n}", torch.nn.BatchNorm1d(dim_out))
block.add_module(f"act_{n}", activation())

self.net.add_module(f"block_{n}", block)
dim_in = dim_out
# Add last layer
linear_out = torch.nn.Linear(in_features=dim_out, out_features=self.d_out, bias=with_bias)
self.net.add_module("head", linear_out)

self.reset_parameters(init_mode=self.init_mode)

def forward(self, input):
output = self.net(input)
return output

def get_hparams(self):
return {'num_layers': len(self.net),
'hidden_ch': self.hidden_channels,
'init_mode': self.init_mode}

def reset_parameters(self, init_mode=None):
assert init_mode is not None or self.init_mode is not None
self.init_mode = self.init_mode if init_mode is None else init_mode
for module in self.net:
if isinstance(module, torch.nn.Sequential):
tensor = module[0].weight
activation = module[-1].__class__.__name__
elif isinstance(module, torch.nn.Linear):
tensor = module.weight
activation = "Linear"
else:
raise NotImplementedError(module.__class__.__name__)

if "fan_in" == self.init_mode or "fan_out" == self.init_mode:
torch.nn.init.kaiming_uniform_(tensor, mode=self.init_mode, nonlinearity=activation.lower())
elif 'normal' in self.init_mode.lower():
split = self.init_mode.split('l')
std = 0.1 if len(split) == 1 else float(split[1])
torch.nn.init.normal_(tensor, 0, std)
else:
raise NotImplementedError(self.init_mode)

log.info(f"MLP initialized with mode: {self.init_mode}")
if __name__ == "__main__":
G = escnn.group.DihedralGroup(6)
gspace = escnn.gspaces.no_base_space(G)
# Test Invariant EMLP
in_type = escnn.nn.FieldType(gspace, [G.regular_representation] * 5)
out_type = escnn.nn.FieldType(gspace, [G.trivial_representation] * 6)
emlp = EMLP(in_type, out_type,
num_hidden_units=128,
num_layers=3,
activation="ReLU",
head_with_activation=False)
emlp.eval() # Shut down batch norm
x = in_type(torch.randn(1, in_type.size))
y = emlp(x)

for g in G.elements:
g_x = in_type(in_type.transform_fibers(x.tensor, g)) # Compute g · x
g_y = emlp(g_x) # Compute g · y
assert torch.allclose(y.tensor, g_y.tensor, rtol=1e-4, atol=1e-4), \
f"{g} invariance failed {y.tensor} != {g_y.tensor}"

# Test Equivariant EMLP
in_type = escnn.nn.FieldType(gspace, [G.regular_representation] * 5)
out_type = escnn.nn.FieldType(gspace, [G.regular_representation] * 2)
emlp = EMLP(in_type, out_type,
num_hidden_units=128,
num_layers=3,
activation="ReLU",
head_with_activation=False)
emlp.eval() # Shut down batch norm

x = in_type(torch.randn(1, in_type.size))
y = emlp(x)

for g in G.elements:
g_x = in_type(in_type.transform_fibers(x.tensor, g)) # Compute g · x
g_y_gt = out_type(out_type.transform_fibers(y.tensor, g)) # Compute ground truth g · y
g_y = emlp(g_x) # Compute g · y
assert torch.allclose(g_y_gt.tensor, g_y.tensor, rtol=1e-4, atol=1e-4), \
f"{g} invariance failed {g_y_gt.tensor} != {g_y.tensor}"
Loading

0 comments on commit 270bc9d

Please sign in to comment.