Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: avoid redundant data when saving Rotator models #229

Merged
merged 3 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/models/cross/test_hilbert_mca_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_fit(mca_model):
mca_rotator = HilbertMCARotator(n_modes=2)
mca_rotator.fit(mca_model)

assert hasattr(mca_rotator, "model")
assert hasattr(mca_rotator, "model_data")
assert hasattr(mca_rotator, "data")


Expand Down
2 changes: 1 addition & 1 deletion tests/models/cross/test_mca_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_fit(mca_model):
mca_rotator = MCARotator(n_modes=4)
mca_rotator.fit(mca_model)

assert hasattr(mca_rotator, "model")
assert hasattr(mca_rotator, "model_data")
assert hasattr(mca_rotator, "data")


Expand Down
6 changes: 3 additions & 3 deletions tests/models/single/test_eof_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ def test_fit(eof_model):
eof_rotator.fit(eof_model)

assert hasattr(
eof_rotator, "model"
), 'The attribute "model" should be populated after fitting.'
eof_rotator, "model_data"
), 'The attribute "model_data" should be populated after fitting.'
assert hasattr(
eof_rotator, "data"
), 'The attribute "data" should be populated after fitting.'
assert isinstance(eof_rotator.model, EOF)
assert isinstance(eof_rotator.model_data, DataContainer)
assert isinstance(eof_rotator.data, DataContainer)


Expand Down
6 changes: 3 additions & 3 deletions tests/models/single/test_hilbert_eof_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ def test_fit(ceof_model):
ceof_rotator.fit(ceof_model)

assert hasattr(
ceof_rotator, "model"
), 'The attribute "model" should be populated after fitting.'
ceof_rotator, "model_data"
), 'The attribute "model_data" should be populated after fitting.'
assert hasattr(
ceof_rotator, "data"
), 'The attribute "data" should be populated after fitting.'
assert isinstance(ceof_rotator.model, HilbertEOF)
assert isinstance(ceof_rotator.model_data, DataContainer)
assert isinstance(ceof_rotator.data, DataContainer)


Expand Down
46 changes: 23 additions & 23 deletions xeofs/cross/cpcca_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,35 +98,34 @@ def __init__(
self.whitener1 = Whitener()
self.whitener2 = Whitener()
self.data = DataContainer()
self.model = CPCCA()
self.model_data = DataContainer()

self.sorted = False

def get_serialization_attrs(self) -> dict:
return dict(
data=self.data,
model_data=self.model_data,
preprocessor1=self.preprocessor1,
preprocessor2=self.preprocessor2,
whitener1=self.whitener1,
whitener2=self.whitener2,
model=self.model,
sorted=self.sorted,
sample_name=self.sample_name,
feature_name=self.feature_name,
)

def _fit_algorithm(self, model) -> Self:
self.model = model
self.preprocessor1 = model.preprocessor1
self.preprocessor2 = model.preprocessor2
self.whitener1 = model.whitener1
self.whitener2 = model.whitener2
self.sample_name = self.model.sample_name
self.feature_name = self.model.feature_name
self.sample_name = model.sample_name
self.feature_name = model.feature_name
self.sorted = False

common_feature_dim = "common_feature_dim"
feature_name = self._get_feature_name()
feature_name = model.feature_name

n_modes = self._params["n_modes"]
power = self._params["power"]
Expand All @@ -145,12 +144,12 @@ def _fit_algorithm(self, model) -> Self:
# fraction" which is conserved under rotation, but does not have a clear
# interpretation as the term covariance fraction is only correct when
# both data sets X and Y are equal and MCA reduces to PCA.
svalues = self.model.data["singular_values"].sel(mode=slice(1, n_modes))
svalues = model.data["singular_values"].sel(mode=slice(1, n_modes))
scaling = np.sqrt(svalues)

# Get unrotated singular vectors
Qx = self.model.data["components1"].sel(mode=slice(1, n_modes))
Qy = self.model.data["components2"].sel(mode=slice(1, n_modes))
Qx = model.data["components1"].sel(mode=slice(1, n_modes))
Qy = model.data["components2"].sel(mode=slice(1, n_modes))

# Unwhiten and back-transform into physical space
Qx = self.whitener1.inverse_transform_components(Qx)
Expand Down Expand Up @@ -233,8 +232,8 @@ def _fit_algorithm(self, model) -> Self:
idx_modes_sorted.coords.update(squared_covariance.coords)

# Rotate scores using rotation matrix
scores1 = self.model.data["scores1"].sel(mode=slice(1, n_modes))
scores2 = self.model.data["scores2"].sel(mode=slice(1, n_modes))
scores1 = model.data["scores1"].sel(mode=slice(1, n_modes))
scores2 = model.data["scores2"].sel(mode=slice(1, n_modes))

scores1 = self.whitener1.inverse_transform_scores(scores1)
scores2 = self.whitener2.inverse_transform_scores(scores2)
Expand All @@ -260,12 +259,18 @@ def _fit_algorithm(self, model) -> Self:
scores1_rot = scores1_rot * modes_sign
scores2_rot = scores2_rot * modes_sign

# Create data container
# Create data container for Rotator and original model data
self.model_data.add(name="singular_values", data=model.data["singular_values"])
self.model_data.add(name="components1", data=model.data["components1"])
self.model_data.add(name="components2", data=model.data["components2"])

# Assigning input data to the Rotator object allows us to inherit some functionalities from the original model
# like squared_covariance_fraction(), homogeneous_patterns() etc.
self.data.add(
name="input_data1", data=self.model.data["input_data1"], allow_compute=False
name="input_data1", data=model.data["input_data1"], allow_compute=False
)
self.data.add(
name="input_data2", data=self.model.data["input_data2"], allow_compute=False
name="input_data2", data=model.data["input_data2"], allow_compute=False
)
self.data.add(name="components1", data=Qx_rot)
self.data.add(name="components2", data=Qy_rot)
Expand All @@ -274,7 +279,7 @@ def _fit_algorithm(self, model) -> Self:
self.data.add(name="squared_covariance", data=squared_covariance)
self.data.add(
name="total_squared_covariance",
data=self.model.data["total_squared_covariance"],
data=model.data["total_squared_covariance"],
)

self.data.add(name="idx_modes_sorted", data=idx_modes_sorted)
Expand Down Expand Up @@ -337,14 +342,14 @@ def transform(
)
RinvT = RinvT.rename({"mode_n": "mode"})

scaling = self.model.data["singular_values"].sel(mode=slice(1, n_modes))
scaling = self.model_data["singular_values"].sel(mode=slice(1, n_modes))
scaling = np.sqrt(scaling)

results = []

if X is not None:
# Select the (non-rotated) singular vectors of the first dataset
comps1 = self.model.data["components1"].sel(mode=slice(1, n_modes))
comps1 = self.model_data["components1"].sel(mode=slice(1, n_modes))

# Preprocess the data
comps1 = self.whitener1.inverse_transform_components(comps1)
Expand Down Expand Up @@ -374,7 +379,7 @@ def transform(

if Y is not None:
# Select the (non-rotated) singular vectors of the second dataset
comps2 = self.model.data["components2"].sel(mode=slice(1, n_modes))
comps2 = self.model_data["components2"].sel(mode=slice(1, n_modes))

# Preprocess the data
comps2 = self.whitener2.inverse_transform_components(comps2)
Expand Down Expand Up @@ -451,9 +456,6 @@ def _compute_rot_mat_inv_trans(self, rotation_matrix, input_dims) -> xr.DataArra
rotation_matrix = rotation_matrix.conj().transpose(*input_dims)
return rotation_matrix

def _get_feature_name(self):
return self.model.feature_name


class ComplexCPCCARotator(CPCCARotator, ComplexCPCCA):
"""Rotate a solution obtained from ``xe.cross.ComplexCPCCA``.
Expand Down Expand Up @@ -517,7 +519,6 @@ class ComplexCPCCARotator(CPCCARotator, ComplexCPCCA):
def __init__(self, **kwargs):
CPCCARotator.__init__(self, **kwargs)
self.attrs.update({"model": "Rotated Complex CPCCA"})
self.model = ComplexCPCCA()


class HilbertCPCCARotator(ComplexCPCCARotator, HilbertCPCCA):
Expand Down Expand Up @@ -582,7 +583,6 @@ class HilbertCPCCARotator(ComplexCPCCARotator, HilbertCPCCA):
def __init__(self, **kwargs):
ComplexCPCCARotator.__init__(self, **kwargs)
self.attrs.update({"model": "Rotated Hilbert CPCCA"})
self.model = HilbertCPCCA()

def transform(
self, X: DataObject | None = None, Y: DataObject | None = None, normalized=False
Expand Down
3 changes: 0 additions & 3 deletions xeofs/cross/mca_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ def __init__(

# Define analysis-relevant meta data
self.attrs.update({"model": "Rotated MCA"})
self.model = MCA()


class ComplexMCARotator(ComplexCPCCARotator, ComplexMCA):
Expand Down Expand Up @@ -149,7 +148,6 @@ def __init__(
compute=compute,
)
self.attrs.update({"model": "Rotated Complex MCA"})
self.model = ComplexMCA()


class HilbertMCARotator(HilbertCPCCARotator, HilbertMCA):
Expand Down Expand Up @@ -226,4 +224,3 @@ def __init__(
compute=compute,
)
self.attrs.update({"model": "Rotated Hilbert MCA"})
self.model = HilbertMCA()
17 changes: 10 additions & 7 deletions xeofs/single/eof_rotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,15 @@ def __init__(
# Attach empty objects
self.preprocessor = Preprocessor()
self.data = DataContainer()
self.model = EOF()
self.model_data = DataContainer()

self.sorted = False

def get_serialization_attrs(self) -> dict:
return dict(
data=self.data,
preprocessor=self.preprocessor,
model=self.model,
model_data=self.model_data,
sorted=self.sorted,
)

Expand All @@ -117,7 +117,6 @@ def fit(self, model) -> Self:
return self

def _fit_algorithm(self, model) -> Self:
self.model = model
self.preprocessor = model.preprocessor
self.sample_name = model.sample_name
self.feature_name = model.feature_name
Expand Down Expand Up @@ -189,6 +188,10 @@ def _fit_algorithm(self, model) -> Self:
scores = scores * modes_sign

# Store the results
self.model_data.add(model.data["norms"], "singular_values")
self.model_data.add(model.data["components"], "components")

# Assigning input data to the Rotator object allows us to inherit some functionalities from the original model
self.data.add(model.data["input_data"], "input_data", allow_compute=False)
self.data.add(rot_components, "components")
self.data.add(scores, "scores")
Expand Down Expand Up @@ -224,10 +227,12 @@ def _sort_by_variance(self):
def _transform_algorithm(self, X: DataArray) -> DataArray:
n_modes = self._params["n_modes"]

svals = self.model.singular_values().sel(mode=slice(1, self._params["n_modes"]))
svals = self.model_data["singular_values"].sel(
mode=slice(1, self._params["n_modes"])
)
pseudo_norms = self.data["norms"]
# Select the (non-rotated) singular vectors of the first dataset
components = self.model.data["components"].sel(mode=slice(1, n_modes))
components = self.model_data["components"].sel(mode=slice(1, n_modes))

# Compute non-rotated scores by projecting the data onto non-rotated components
projections = xr.dot(X, components) / svals
Expand Down Expand Up @@ -329,7 +334,6 @@ def __init__(
n_modes=n_modes, power=power, max_iter=max_iter, rtol=rtol, compute=compute
)
self.attrs.update({"model": "Rotated Complex EOF analysis"})
self.model = ComplexEOF()


class HilbertEOFRotator(EOFRotator, HilbertEOF):
Expand Down Expand Up @@ -385,7 +389,6 @@ def __init__(
n_modes=n_modes, power=power, max_iter=max_iter, rtol=rtol, compute=compute
)
self.attrs.update({"model": "Rotated Hilbert EOF analysis"})
self.model = HilbertEOF()

def _transform_algorithm(self, data: DataArray) -> DataArray:
# Here we leverage the Method Resolution Order (MRO) to invoke the
Expand Down
Loading