Skip to content

Commit

Permalink
Fix bugs on numpy<1.24
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Ordonez <[email protected]>
  • Loading branch information
Danfoa committed Aug 2, 2023
1 parent 861264a commit 39077e8
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 22 deletions.
36 changes: 18 additions & 18 deletions morpho_symm/groups/isotypic_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import escnn
import networkx as nx
import numpy as np
import scipy
from escnn.group.group import Group, GroupElement
from escnn.group.representation import Representation
from networkx import Graph
Expand Down Expand Up @@ -70,7 +71,8 @@ def rep(g):
rep = representation
# Compute the dimension of the representation
n = rep(G.sample()).shape[0]
for g in G.elements:

for g in G.elements: # Ensure the representation is unitary/orthogonal
error = np.abs((rep(g) @ rep(g).conj().T) - np.eye(n))
assert np.allclose(error, 0), f"Rep {rep} is not unitary: rep(g)@rep(g)^H=\n{(rep(g) @ rep(g).conj().T)}"

Expand All @@ -80,10 +82,9 @@ def rep(g):
return [rep], np.eye(n)

# Eigen-decomposition of matrix `H = P·A·P^-1` reveals the G-invariant subspaces/eigenspaces of the representations.
eivals, eigvects = np.linalg.eigh(H)
eivals, eigvects = np.linalg.eigh(H, UPLO='L')
P = eigvects.conj().T
assert np.allclose(P.conj().T @ np.diag(eivals) @ P, H)
assert np.allclose(P @ P.conj().T, np.eye(n)), "P is not Unitary/Hermitian"

# Eigendcomposition is not guaranteed to block_diagonalize the representation. An additional permutation of the
# rows and columns od the representation might be needed to produce a Jordan block canonical form.
Expand All @@ -102,7 +103,7 @@ def rep(g):
graph = Graph()
graph.add_edges_from(set(edges))
connected_components = [sorted(list(comp)) for comp in nx.connected_components(graph)]
connected_components = sorted(connected_components, key=lambda x: len(x))
connected_components = sorted(connected_components, key=lambda x: (len(x), min(x))) # Impose a canonical order
# If connected components are not adjacent dimensions, say subrep_1_dims = [0,2] and subrep_2_dims = [1,3] then
# We permute them to get a jordan block canonical form. I.e. subrep_1_dims = [0,1] and subrep_2_dims = [2,3].
oneline_notation = list(itertools.chain.from_iterable([list(comp) for comp in connected_components]))
Expand All @@ -123,10 +124,11 @@ def rep(g):
# Transform the decomposed representation into the Jordan Cannonical Form (jcf)
jcf_rep = (PJ @ decomposed_reps[g] @ PJ.T)
# Check Jordan Cannonical Form TODO: Extract this to a utils. function
above_block, below_block = jcf_rep[0:block_start, block_start:block_end], jcf_rep[block_end:,
block_start:block_end]
left_block, right_block = jcf_rep[block_start:block_end, 0:block_start], jcf_rep[block_start:block_end,
block_end:]
above_block = jcf_rep[0:block_start, block_start:block_end]
below_block = jcf_rep[block_end:, block_start:block_end]
left_block = jcf_rep[block_start:block_end, 0:block_start]
right_block = jcf_rep[block_start:block_end, block_end:]

assert np.allclose(above_block, 0) or above_block.size == 0, "Non zero elements above block"
assert np.allclose(below_block, 0) or below_block.size == 0, "Non zero elements below block"
assert np.allclose(left_block, 0) or left_block.size == 0, "Non zero elements left of block"
Expand Down Expand Up @@ -196,14 +198,10 @@ def rep(g):
Q = P @ Q_external @ Q_internal

# Test isotypic decomposition.
assert np.allclose(Q @ np.linalg.inv(Q), np.eye(n)), "Q is not unitary."
assert np.allclose(Q @ Q.conj().T, np.eye(n)), "Q is not unitary."
for g in G.elements:
rep(g)
g_iso = block_diag(*[irrep[g] if isinstance(irrep, dict) else irrep(g) for irrep in sorted_irreps])
P.T @ g_iso @ P
Q_external.conj().T @ P.T @ g_iso @ P @ Q_external
Q_internal.conj().T @ Q_external.conj().T @ P.T @ g_iso @ P @ Q_external @ Q_internal
error = np.abs(g_iso - (Q @ rep(g) @ np.linalg.inv(Q)))
error = np.abs(g_iso - (Q @ rep(g) @ Q.conj().T))
assert np.allclose(error, 0), f"Q @ rep[g] @ Q^-1 != block_diag[irreps[g]], for g={g}. Error \n:{error}"

return sorted_irreps, Q
Expand Down Expand Up @@ -324,18 +322,19 @@ def rep(g):
for g in G.elements:
iso_re_g = block_diag(*[irrep(g) for irrep in escnn_real_irreps])
iso_cplx_g = block_diag(*[cplx_irrep[g] for cplx_irrep in cplx_irreps])
rec_iso_re_g = Q_iso_cplx2iso_re @ P @ iso_cplx_g @ (Q_iso_cplx2iso_re @ P).conj().T
rec_iso_re_g = (Q_iso_cplx2iso_re @ P) @ iso_cplx_g @ (Q_iso_cplx2iso_re @ P).conj().T
error = np.abs(iso_re_g - rec_iso_re_g)
assert np.isclose(error, 0).all(), "Error found in the conversion of Real irreps to Complex irreps"
assert np.isclose(error, 0).all(), "Error in the conversion of Real irreps to Complex irreps"

# Now we have an orthogonal transformation between the input `rep` and `iso_re`.
# | iso_cplx(g) |
# (Q_iso_cplx2iso_re @ P @ Q) @ rep(g) @ (Q^-1 @ P^-1 @ Q_iso_cplx2iso_re^-1) = Q_re @ rep(g) @ Q_re^-1 = iso_re(g)
Q_re = Q_iso_cplx2iso_re @ P @ Q

assert np.allclose(Q_re @ Q_re.conj().T, np.eye(Q_re.shape[0])), "Q_re is not an orthogonal transformation"
assert np.allclose(np.imag(Q_re), 0), "Q_re is not a real matrix"
if np.allclose(np.imag(Q_re), 0):
Q_re = np.real(Q_re) # Remove numerical noise and ensure rep(g) is of dtype: float instead of cfloat

Q_re = np.real(Q_re) # Remove numerical noise
# Then we have that `Q_re^-1 @ iso_re(g) @ Q_re = rep(g)`
reconstructed_rep = Representation(G, name="reconstructed", irreps=[irrep.id for irrep in escnn_real_irreps],
change_of_basis=Q_re.conj().T)
Expand All @@ -346,6 +345,7 @@ def rep(g):
error = np.abs(g_true - g_rec)
error[error < 1e-10] = 0
assert np.allclose(error, 0), f"Reconstructed rep do not match input rep. g={g}, error:\n{error}"
assert np.allclose(np.imag(g_rec), 0), f"Reconstructed rep not real for g={g}: \n{g_rec}"

return reconstructed_rep

Expand Down
8 changes: 4 additions & 4 deletions morpho_symm/robots/PinSimWrapper.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import logging
from abc import ABC, abstractmethod
from typing import Collection, Optional, Tuple, Union
from typing import Collection, List, Optional, Tuple, Union

import numpy as np
import scipy
Expand Down Expand Up @@ -408,8 +408,8 @@ def __repr__(self):
class JointWrapper:

def __init__(self, type: Union[str, int], idx_q: int, idx_v: int, nq: int, nv: int,
pos_limit_low: Union[list[float], float] = -np.inf,
pos_limit_high: Union[list[float], float] = np.inf):
pos_limit_low: Union[List[float], float] = -np.inf,
pos_limit_high: Union[List[float], float] = np.inf):
self.type = type
self.idx_q = idx_q
self.idx_v = idx_v
Expand Down Expand Up @@ -485,7 +485,7 @@ def substract_configuration(self, q1, q2) -> State:
raise NotImplementedError()

@property
def state_idx(self) -> Tuple[list[int], list[int]]:
def state_idx(self) -> Tuple[List[int], List[int]]:
idx_q = list(range(self.idx_q, self.idx_q + self.nq))
idx_v = list(range(self.idx_v, self.idx_v + self.nv))
return idx_q, idx_v
Expand Down

0 comments on commit 39077e8

Please sign in to comment.