-
Notifications
You must be signed in to change notification settings - Fork 27
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
New multirep PR initiated #496
base: develop
Are you sure you want to change the base?
Changes from 80 commits
60a477f
e5ae472
6c56fa8
63bf520
9222ccd
38fb967
9d4195e
bc6782a
78e80ed
f79b6ca
c2b932c
f8fe74e
0f22fdf
aa7a438
60b912e
d4cdd95
fbb636e
96a525e
8ab3750
92df65f
b2deb23
43ee95b
615c5e3
698e426
9408efc
9190c4f
f8b14f7
c5578b8
364ac77
fe64727
5a51d0e
9458c67
5af0008
5cd6170
d0b8091
4174725
01791a2
218f3c8
8e4775c
81104f2
18b22e3
265b854
9095417
8ad4249
e07761e
b6d2a71
43d20cd
6557de6
ce28399
1bd62aa
e895788
0942b9c
d25a08a
7145fd7
c42e20f
56e88e5
4afbd1e
7497575
4f88920
f2d4f65
f843f04
3132357
3ef97b0
17629d3
af395f4
42f13e8
7a76877
a62200b
3c8261b
a02ffe8
62bc1b6
5b7ad47
fdc50d0
a8c7c17
4ba09e7
512072a
fdfbe02
932c9e7
af1552e
7316083
8a5c93e
7940a29
983c275
23acbda
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -25,6 +25,7 @@ | |||||
import numbers | ||||||
from functools import cached_property | ||||||
|
||||||
import copy | ||||||
import numpy as np | ||||||
from numpy.typing import ArrayLike | ||||||
import ipywidgets as widgets | ||||||
|
@@ -46,6 +47,7 @@ | |||||
from mrmustard.lab_dev.wires import Wires | ||||||
from mrmustard.physics.triples import identity_Abc | ||||||
|
||||||
|
||||||
__all__ = ["CircuitComponent"] | ||||||
|
||||||
|
||||||
|
@@ -109,6 +111,7 @@ def __init__( | |||||
) | ||||||
if self._representation: | ||||||
self._representation = self._representation.reorder(tuple(perm)) | ||||||
self._index_representation = {i: ("B", None) for i in self.wires.indices} | ||||||
|
||||||
def _serialize(self) -> tuple[dict[str, Any], dict[str, ArrayLike]]: | ||||||
""" | ||||||
|
@@ -168,6 +171,13 @@ def adjoint(self) -> CircuitComponent: | |||||
ret.short_name = self.short_name | ||||||
for param in self.parameter_set.all_parameters.values(): | ||||||
ret._add_parameter(param) | ||||||
|
||||||
# handling index representations: | ||||||
for i, j in enumerate(kets): | ||||||
ret._index_representation[i] = self._index_representation[j] | ||||||
for i, j in enumerate(bras): | ||||||
ret._index_representation[i + len(kets)] = self._index_representation[j] | ||||||
|
||||||
return ret | ||||||
|
||||||
@property | ||||||
|
@@ -187,6 +197,16 @@ def dual(self) -> CircuitComponent: | |||||
ret.short_name = self.short_name | ||||||
for param in self.parameter_set.all_parameters.values(): | ||||||
ret._add_parameter(param) | ||||||
|
||||||
# handling index representations: | ||||||
for i, j in enumerate(ib): | ||||||
ret._index_representation[i] = self._index_representation[j] | ||||||
for i, j in enumerate(ob): | ||||||
ret._index_representation[i + len(ib)] = self._index_representation[j] | ||||||
for i, j in enumerate(ik): | ||||||
ret._index_representation[i + len(ib + ob)] = self._index_representation[j] | ||||||
for i, j in enumerate(ok): | ||||||
ret._index_representation[i + len(ib + ob + ik)] = self._index_representation[j] | ||||||
return ret | ||||||
|
||||||
@cached_property | ||||||
|
@@ -423,6 +443,7 @@ def _from_attributes( | |||||
ret._name = name | ||||||
ret._representation = representation | ||||||
ret._wires = wires | ||||||
ret._index_representation = {i: ("B", None) for i in wires.indices} | ||||||
return ret | ||||||
return CircuitComponent(representation, wires, name) | ||||||
|
||||||
|
@@ -549,6 +570,8 @@ def on(self, modes: Sequence[int]) -> CircuitComponent: | |||||
modes_in_ket=set(modes) if ik else set(), | ||||||
) | ||||||
|
||||||
ret._index_representation = copy.deepcopy(self._index_representation) | ||||||
|
||||||
return ret | ||||||
|
||||||
def to_fock(self, shape: int | Sequence[int] | None = None) -> CircuitComponent: | ||||||
|
@@ -587,9 +610,10 @@ def to_fock(self, shape: int | Sequence[int] | None = None) -> CircuitComponent: | |||||
del ret.manual_shape | ||||||
return ret | ||||||
|
||||||
def to_bargmann(self) -> CircuitComponent: | ||||||
def to_bargmann(self, indices: Sequence[int] | None = None) -> CircuitComponent: | ||||||
r""" | ||||||
Returns a new circuit component with the same attributes as this and a ``Bargmann`` representation. | ||||||
Returns a new circuit component with the same attributes as this and a ``Bargmann`` representation on the specified "indices." | ||||||
If "indices" are not specified, all indices are transformed into bargmann. | ||||||
.. code-block:: | ||||||
|
||||||
>>> from mrmustard.lab_dev import Dgate | ||||||
|
@@ -604,9 +628,17 @@ def to_bargmann(self) -> CircuitComponent: | |||||
>>> assert d_bargmann.wires == d.wires | ||||||
>>> assert isinstance(d_bargmann.representation, Bargmann) | ||||||
""" | ||||||
if isinstance(self.representation, Bargmann): | ||||||
return self | ||||||
else: | ||||||
apchytr marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
ret = copy.deepcopy(self) | ||||||
if isinstance(self.representation, Bargmann): # TODO: better name for Bargmann class | ||||||
# check cc rep | ||||||
if not indices: | ||||||
indices = self.wires.indices | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
ret = ret._apply_btoq_for_change_of_rep(indices) | ||||||
ret = ret._apply_btops_for_change_of_rep(indices) | ||||||
|
||||||
elif isinstance(self.representation, Fock): | ||||||
if self.representation.ansatz._original_abc_data: | ||||||
A, b, c = self.representation.ansatz._original_abc_data | ||||||
else: | ||||||
|
@@ -620,7 +652,7 @@ def to_bargmann(self) -> CircuitComponent: | |||||
ret = self._from_attributes(bargmann, self.wires, self.name) | ||||||
if "manual_shape" in ret.__dict__: | ||||||
del ret.manual_shape | ||||||
return ret | ||||||
return ret | ||||||
|
||||||
def _add_parameter(self, parameter: Constant | Variable): | ||||||
r""" | ||||||
|
@@ -650,6 +682,72 @@ def _getitem_builtin(self, modes: set[int]): | |||||
kwargs = self.parameter_set[items].to_dict() | ||||||
return self.__class__(modes=modes, **kwargs) | ||||||
|
||||||
def _apply_btops_for_change_of_rep(self, indices: Sequence[int]) -> CircuitComponent: | ||||||
r""" | ||||||
Helper function for change of representation in to_bargmann() | ||||||
|
||||||
Args: | ||||||
indices: the set of indices that we want to be represented in bargmann. | ||||||
|
||||||
Output: | ||||||
the cc object with Bargmann representation on the specified indices. The representations on the other wires remain intact. | ||||||
""" | ||||||
|
||||||
from .circuit_components_utils import BtoPS | ||||||
|
||||||
ret = copy.deepcopy(self) | ||||||
|
||||||
for i in indices: | ||||||
name, arg = self._index_representation[i] | ||||||
|
||||||
if name == "PS": | ||||||
ret._index_representation[i] = ("B", None) | ||||||
m = self.wires.index_to_mode_dict[i] | ||||||
if i in self.wires.output.bra.indices: | ||||||
if m not in self.wires.output.ket.modes: | ||||||
raise ValueError( | ||||||
f"The object does not have a consistent representation. Mode {m} with PS representation has appeared only on the output bra." | ||||||
) | ||||||
friend_index = self.wires.index_dicts[2][m] | ||||||
ret._index_representation[friend_index] = ("B", None) | ||||||
ret = ret @ BtoPS([m], s=arg).adjoint.inverse() | ||||||
|
||||||
if i in self.wires.input.bra.indices: | ||||||
if m not in self.wires.input.ket.modes: | ||||||
raise ValueError( | ||||||
f"The object does not have a consistent representation. Mode {m} with PS representation has appeared only on the input bra." | ||||||
) | ||||||
friend_index = self.wires.index_dicts[3][m] | ||||||
ret._index_representation[friend_index] = ("B", None) | ||||||
ret = BtoPS([m], s=arg).dual.inverse() @ ret | ||||||
|
||||||
return ret | ||||||
|
||||||
def _apply_btoq_for_change_of_rep(self, indices: Sequence[int]) -> CircuitComponent: | ||||||
r""" | ||||||
Helper function for change of representation in to_bargmann() | ||||||
""" | ||||||
|
||||||
from .circuit_components_utils import BtoQ | ||||||
|
||||||
ret = copy.deepcopy(self) | ||||||
for i in indices: | ||||||
name, arg = self._index_representation[i] | ||||||
if name == "Q": | ||||||
ret._index_representation[i] = ("B", None) # perhaps not needed -- can be removed | ||||||
if i in self.wires.output.bra.indices: | ||||||
ret = ret @ BtoQ([self.wires.index_to_mode_dict[i]], phi=arg).adjoint.inverse() | ||||||
if i in self.wires.output.ket.indices: | ||||||
ret = ret @ BtoQ([self.wires.index_to_mode_dict[i]], phi=arg).inverse() | ||||||
if i in self.wires.input.bra.indices: | ||||||
ret = ( | ||||||
BtoQ([self.wires.index_to_mode_dict[i]], phi=arg).dual.adjoint.inverse() | ||||||
@ ret | ||||||
) | ||||||
if i in self.wires.input.ket.indices: | ||||||
ret = BtoQ([self.wires.index_to_mode_dict[i]], phi=arg).dual.inverse() @ ret | ||||||
return ret | ||||||
|
||||||
def _light_copy(self, wires: Wires | None = None) -> CircuitComponent: | ||||||
r""" | ||||||
Creates a "light" copy of this component by referencing its __dict__, except for the wires, | ||||||
|
@@ -690,8 +788,11 @@ def __add__(self, other: CircuitComponent) -> CircuitComponent: | |||||
""" | ||||||
if self.wires != other.wires: | ||||||
raise ValueError("Cannot add components with different wires.") | ||||||
rep = self.representation + other.representation | ||||||
rep = ( | ||||||
self.to_bargmann().representation + other.to_bargmann().representation | ||||||
) # addition occurs in bargmann always | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need it in bargmann? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is to avoid adding things with different representations. Lmk if you have a better approach in mind. |
||||||
name = self.name if self.name == other.name else "" | ||||||
# TODO: go back to bargmann on all modes | ||||||
return self._from_attributes(rep, self.wires, name) | ||||||
|
||||||
def __eq__(self, other) -> bool: | ||||||
|
@@ -700,7 +801,16 @@ def __eq__(self, other) -> bool: | |||||
|
||||||
Compares representations and wires, but not the other attributes (e.g. name and parameter set). | ||||||
""" | ||||||
return self.representation == other.representation and self.wires == other.wires | ||||||
from .circuit_components_utils import BtoQ, BtoPS | ||||||
|
||||||
if (type(self.representation) == type(other.representation) == Fock) or isinstance( | ||||||
self, (BtoQ, BtoPS) | ||||||
): | ||||||
return self.representation == other.representation and self.wires == other.wires | ||||||
else: | ||||||
self_rep = self.to_bargmann().representation | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure how I feel about this why do we need this change? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well.. two things are equal if they are equal in the same representation. I mean, in physics, we call them the same thing. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. But why do we need to call to_bargmann in the
|
||||||
other_rep = other.to_bargmann().representation | ||||||
return self_rep == other_rep and self.wires == other.wires | ||||||
|
||||||
def __matmul__(self, other: CircuitComponent | Scalar) -> CircuitComponent: | ||||||
r""" | ||||||
|
@@ -719,27 +829,96 @@ def __matmul__(self, other: CircuitComponent | Scalar) -> CircuitComponent: | |||||
>>> att = Attenuator([0], 0.5) | ||||||
>>> assert (coh @ att).wires.input.bra # the input bra is still uncontracted | ||||||
""" | ||||||
from .circuit_components_utils import BtoQ, BtoPS | ||||||
|
||||||
if isinstance(other, (numbers.Number, np.ndarray)): | ||||||
return self * other | ||||||
|
||||||
wires_result, perm = self.wires @ other.wires | ||||||
idx_z, idx_zconj = self._matmul_indices(other) | ||||||
if type(self.representation) is type(other.representation): | ||||||
|
||||||
if type(self.representation) is type(other.representation) is Fock: | ||||||
self_rep = self.representation | ||||||
other_rep = other.representation | ||||||
else: | ||||||
self_rep = self.to_bargmann().representation | ||||||
other_rep = other.to_bargmann().representation | ||||||
if ( | ||||||
(not isinstance(self, BtoQ)) | ||||||
and (not isinstance(other, BtoQ)) | ||||||
and (not isinstance(self, BtoPS)) | ||||||
and (not isinstance(other, BtoPS)) | ||||||
): | ||||||
self_copy = copy.deepcopy(self) | ||||||
other_copy = copy.deepcopy(other) | ||||||
index_self, index_other = self_copy.wires.contracted_indices(other_copy.wires) | ||||||
self_rep = self_copy.to_bargmann( | ||||||
index_self | ||||||
).representation # this is where the copy is required (to not send the intial objects back to Bargmann) | ||||||
other_rep = other_copy.to_bargmann(index_other).representation | ||||||
else: | ||||||
self_rep = self.representation | ||||||
other_rep = other.representation | ||||||
|
||||||
rep = self_rep[idx_z] @ other_rep[idx_zconj] | ||||||
rep = rep.reorder(perm) if perm else rep | ||||||
return CircuitComponent._from_attributes(rep, wires_result, None) | ||||||
result = CircuitComponent._from_attributes(rep, wires_result, None) | ||||||
|
||||||
# REMEMBER the representations: | ||||||
# set the index_representation of uncontracted indices: | ||||||
# (this will be overwritten if we have a change of representation e.g. other == BtoQ) | ||||||
for m in other.wires.output.bra.modes: | ||||||
i = result.wires.index_dicts[0][m] | ||||||
j = other.wires.index_dicts[0][m] | ||||||
result._index_representation[i] = other._index_representation[j] | ||||||
for m in other.wires.output.ket.modes: | ||||||
i = result.wires.index_dicts[2][m] | ||||||
j = other.wires.index_dicts[2][m] | ||||||
result._index_representation[i] = other._index_representation[j] | ||||||
|
||||||
for m in self.wires.input.bra.modes: | ||||||
i = result.wires.index_dicts[1][m] | ||||||
j = self.wires.index_dicts[1][m] | ||||||
result._index_representation[i] = self._index_representation[j] | ||||||
for m in self.wires.input.ket.modes: | ||||||
i = result.wires.index_dicts[3][m] | ||||||
j = self.wires.index_dicts[3][m] | ||||||
result._index_representation[i] = self._index_representation[j] | ||||||
|
||||||
# now we check for indices that might have been contracted: | ||||||
idx_1, idx_2 = self.wires.contracted_indices(other.wires) | ||||||
|
||||||
for m in other.wires.input.bra.modes: | ||||||
j = other.wires.index_dicts[1][m] | ||||||
if j not in idx_2: | ||||||
i = result.wires.index_dicts[1][m] | ||||||
result._index_representation[i] = other._index_representation[j] | ||||||
for m in other.wires.input.ket.modes: | ||||||
|
||||||
j = other.wires.index_dicts[3][m] | ||||||
if j not in idx_2: | ||||||
i = result.wires.index_dicts[3][m] | ||||||
result._index_representation[i] = other._index_representation[j] | ||||||
|
||||||
for m in self.wires.output.bra.modes: | ||||||
j = self.wires.index_dicts[0][m] | ||||||
if j not in idx_1: | ||||||
i = result.wires.index_dicts[0][m] | ||||||
result._index_representation[i] = self._index_representation[j] | ||||||
|
||||||
for m in self.wires.output.ket.modes: | ||||||
j = self.wires.index_dicts[2][m] | ||||||
if j not in idx_1: | ||||||
i = result.wires.index_dicts[2][m] | ||||||
result._index_representation[i] = self._index_representation[j] | ||||||
|
||||||
return result | ||||||
|
||||||
def __mul__(self, other: Scalar) -> CircuitComponent: | ||||||
r""" | ||||||
Implements the multiplication by a scalar from the right. | ||||||
""" | ||||||
return self._from_attributes(self.representation * other, self.wires, self.name) | ||||||
ret = self._from_attributes(self.representation * other, self.wires, self.name) | ||||||
ret._index_representation = self._index_representation | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you need to copy? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I am not sure I understand the question. Can you please elaborate? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same issue you encountered before where the solution was to deepcopy the dictionary. As in, |
||||||
return ret | ||||||
|
||||||
def __repr__(self) -> str: | ||||||
repr = self.representation | ||||||
|
@@ -828,6 +1007,7 @@ def __rshift__(self, other: CircuitComponent | numbers.Number) -> CircuitCompone | |||||
msg = f"``>>`` not supported between {self} and {other} because it's not clear " | ||||||
msg += "whether or where to add bra wires. Use ``@`` instead and specify all the components." | ||||||
raise ValueError(msg) | ||||||
|
||||||
return self._rshift_return(ret) | ||||||
|
||||||
def __sub__(self, other: CircuitComponent) -> CircuitComponent: | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you need to incorporate
_index_representation
into the serialization of the object?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm.. not sure about that. But I guess it makes some sense to keep
_index_representation
as a hidden attribute since there are many calls to circuit component initializer and I might be missing correcting some when doing the change.