Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New multirep PR initiated #496

Open
wants to merge 84 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 80 commits
Commits
Show all changes
84 commits
Select commit Hold shift + click to select a range
60a477f
New multirep PR initiated
arsalan-motamedi Oct 2, 2024
e5ae472
working on the errors
arsalan-motamedi Oct 4, 2024
6c56fa8
updates from Anthony's ntoqInv branch
arsalan-motamedi Oct 4, 2024
63bf520
updating btoq and btops
arsalan-motamedi Oct 4, 2024
9222ccd
all CC tests pass :)
arsalan-motamedi Oct 7, 2024
38fb967
making codefactor happy
arsalan-motamedi Oct 7, 2024
9d4195e
conflict resolved
arsalan-motamedi Oct 7, 2024
bc6782a
not sure what's happening
arsalan-motamedi Oct 7, 2024
78e80ed
Docstrings added
arsalan-motamedi Oct 7, 2024
f79b6ca
minimal changes
arsalan-motamedi Oct 7, 2024
c2b932c
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
apchytr Oct 7, 2024
f8fe74e
fixing errors: BtoPS.s.value
arsalan-motamedi Oct 7, 2024
0f22fdf
fixing BtoPS errors
arsalan-motamedi Oct 7, 2024
aa7a438
BtoQ phi issues fixed
arsalan-motamedi Oct 7, 2024
60b912e
BtoPS s issues fixed
arsalan-motamedi Oct 7, 2024
d4cdd95
changing stuff regarding BtoQ/BtoPS
arsalan-motamedi Oct 9, 2024
fbb636e
formatting
arsalan-motamedi Oct 9, 2024
96a525e
Some weird issues with BtoQ and BtoPS (Constant / float / etc) fixed
arsalan-motamedi Oct 9, 2024
8ab3750
formatting
arsalan-motamedi Oct 9, 2024
92df65f
slight changes
arsalan-motamedi Oct 9, 2024
b2deb23
tests fixed
arsalan-motamedi Oct 10, 2024
43ee95b
removed a print command :)
arsalan-motamedi Oct 10, 2024
615c5e3
Adding the ability to handle BtoQ.inverse and update the representation
arsalan-motamedi Oct 10, 2024
698e426
WIP
arsalan-motamedi Oct 10, 2024
9408efc
Ket matmul revised
arsalan-motamedi Oct 10, 2024
9190c4f
formatting
arsalan-motamedi Oct 10, 2024
f8b14f7
deepcopy change
arsalan-motamedi Oct 11, 2024
c5578b8
removing print statements
arsalan-motamedi Oct 11, 2024
364ac77
Now it's safe to do dual
arsalan-motamedi Oct 11, 2024
fe64727
formatting
arsalan-motamedi Oct 11, 2024
5a51d0e
test to take care of dual matmul
arsalan-motamedi Oct 11, 2024
9458c67
merge conflict resolution
arsalan-motamedi Oct 11, 2024
5af0008
formatting
arsalan-motamedi Oct 11, 2024
5cd6170
the "on" method of circuitcomponent revised (with multirep)
arsalan-motamedi Oct 11, 2024
d0b8091
added adjoint
arsalan-motamedi Oct 11, 2024
4174725
some tests regarding adoint, dual, and "on"
arsalan-motamedi Oct 11, 2024
01791a2
formatting
arsalan-motamedi Oct 11, 2024
218f3c8
test improved
arsalan-motamedi Oct 11, 2024
8e4775c
tests + code improvements (added copying when doing to_bargmann in ma…
arsalan-motamedi Oct 11, 2024
81104f2
weird tests added for logic of matmul
arsalan-motamedi Oct 11, 2024
18b22e3
minor improvement
arsalan-motamedi Oct 11, 2024
265b854
monir
arsalan-motamedi Oct 11, 2024
9095417
added wire tests
arsalan-motamedi Oct 11, 2024
8ad4249
formatting
arsalan-motamedi Oct 11, 2024
e07761e
codefactor changes
arsalan-motamedi Oct 11, 2024
b6d2a71
making codefactor a but happier
arsalan-motamedi Oct 11, 2024
43d20cd
addition's test
arsalan-motamedi Oct 11, 2024
6557de6
Adding representation to BtoQ/BtoPS and removing if statements on the…
arsalan-motamedi Oct 11, 2024
ce28399
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
arsalan-motamedi Oct 15, 2024
1bd62aa
Improving btoq
arsalan-motamedi Oct 15, 2024
e895788
added btoq representation tests
arsalan-motamedi Oct 15, 2024
0942b9c
formatting
arsalan-motamedi Oct 15, 2024
d25a08a
making codefactor happy
arsalan-motamedi Oct 15, 2024
7145fd7
telling codefactor to let go of "too-many-instance-attributes" error
arsalan-motamedi Oct 15, 2024
c42e20f
just restarging
arsalan-motamedi Oct 15, 2024
56e88e5
idk what happened-- just brought it back to a few commits ago
arsalan-motamedi Oct 15, 2024
4afbd1e
making codefactor happy
arsalan-motamedi Oct 15, 2024
7497575
nothing important
arsalan-motamedi Oct 15, 2024
4f88920
formatting
arsalan-motamedi Oct 15, 2024
f2d4f65
Paring wire updates on BtoPS rshifts + DM rshift updated + neater for…
arsalan-motamedi Oct 15, 2024
f843f04
fixing a typo in DM rshift test
arsalan-motamedi Oct 15, 2024
3132357
fixing serialize issue
arsalan-motamedi Oct 15, 2024
3ef97b0
making codefactor happy
arsalan-motamedi Oct 15, 2024
17629d3
formatting
arsalan-motamedi Oct 15, 2024
af395f4
added to_quadrature tests for checking the change of representation
arsalan-motamedi Oct 15, 2024
42f13e8
formatting
arsalan-motamedi Oct 15, 2024
7a76877
some important changes (some error fixed bc I forgot to put .modes an…
arsalan-motamedi Oct 15, 2024
a62200b
a fatal error in btops inverse fixed
arsalan-motamedi Oct 15, 2024
3c8261b
btoq inverse improved
arsalan-motamedi Oct 15, 2024
a02ffe8
improving tests for codecov
arsalan-motamedi Oct 15, 2024
62bc1b6
added forgotten assert
arsalan-motamedi Oct 15, 2024
5b7ad47
now scalar product remembers the representation
arsalan-motamedi Oct 15, 2024
fdc50d0
reformatting
arsalan-motamedi Oct 15, 2024
a8c7c17
Addressing some of Anthony's comments
arsalan-motamedi Oct 16, 2024
4ba09e7
removed an extra file I had generated
arsalan-motamedi Oct 16, 2024
512072a
Merge branch 'develop' into multirep-v2
arsalan-motamedi Oct 16, 2024
fdfbe02
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
arsalan-motamedi Oct 16, 2024
932c9e7
_index_representation
apchytr Oct 17, 2024
af1552e
Merge branch 'develop' of https://github.com/XanaduAI/MrMustard into …
arsalan-motamedi Oct 17, 2024
7316083
removed redundancies
arsalan-motamedi Oct 17, 2024
8a5c93e
Corrected the issue of having a multiplication of BtoQ stuff
arsalan-motamedi Oct 22, 2024
7940a29
Final improvements
arsalan-motamedi Oct 22, 2024
983c275
btoq update
arsalan-motamedi Oct 22, 2024
23acbda
codefactor improvement
arsalan-motamedi Oct 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 193 additions & 13 deletions mrmustard/lab_dev/circuit_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import numbers
from functools import cached_property

import copy
import numpy as np
from numpy.typing import ArrayLike
import ipywidgets as widgets
Expand All @@ -46,6 +47,7 @@
from mrmustard.lab_dev.wires import Wires
from mrmustard.physics.triples import identity_Abc


__all__ = ["CircuitComponent"]


Expand Down Expand Up @@ -109,6 +111,7 @@ def __init__(
)
if self._representation:
self._representation = self._representation.reorder(tuple(perm))
self._index_representation = {i: ("B", None) for i in self.wires.indices}

def _serialize(self) -> tuple[dict[str, Any], dict[str, ArrayLike]]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need to incorporate _index_representation into the serialization of the object?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm.. not sure about that. But I guess it makes some sense to keep _index_representationas a hidden attribute since there are many calls to circuit component initializer and I might be missing correcting some when doing the change.

"""
Expand Down Expand Up @@ -168,6 +171,13 @@ def adjoint(self) -> CircuitComponent:
ret.short_name = self.short_name
for param in self.parameter_set.all_parameters.values():
ret._add_parameter(param)

# handling index representations:
for i, j in enumerate(kets):
ret._index_representation[i] = self._index_representation[j]
for i, j in enumerate(bras):
ret._index_representation[i + len(kets)] = self._index_representation[j]

return ret

@property
Expand All @@ -187,6 +197,16 @@ def dual(self) -> CircuitComponent:
ret.short_name = self.short_name
for param in self.parameter_set.all_parameters.values():
ret._add_parameter(param)

# handling index representations:
for i, j in enumerate(ib):
ret._index_representation[i] = self._index_representation[j]
for i, j in enumerate(ob):
ret._index_representation[i + len(ib)] = self._index_representation[j]
for i, j in enumerate(ik):
ret._index_representation[i + len(ib + ob)] = self._index_representation[j]
for i, j in enumerate(ok):
ret._index_representation[i + len(ib + ob + ik)] = self._index_representation[j]
return ret

@cached_property
Expand Down Expand Up @@ -423,6 +443,7 @@ def _from_attributes(
ret._name = name
ret._representation = representation
ret._wires = wires
ret._index_representation = {i: ("B", None) for i in wires.indices}
return ret
return CircuitComponent(representation, wires, name)

Expand Down Expand Up @@ -549,6 +570,8 @@ def on(self, modes: Sequence[int]) -> CircuitComponent:
modes_in_ket=set(modes) if ik else set(),
)

ret._index_representation = copy.deepcopy(self._index_representation)

return ret

def to_fock(self, shape: int | Sequence[int] | None = None) -> CircuitComponent:
Expand Down Expand Up @@ -587,9 +610,10 @@ def to_fock(self, shape: int | Sequence[int] | None = None) -> CircuitComponent:
del ret.manual_shape
return ret

def to_bargmann(self) -> CircuitComponent:
def to_bargmann(self, indices: Sequence[int] | None = None) -> CircuitComponent:
r"""
Returns a new circuit component with the same attributes as this and a ``Bargmann`` representation.
Returns a new circuit component with the same attributes as this and a ``Bargmann`` representation on the specified "indices."
If "indices" are not specified, all indices are transformed into bargmann.
.. code-block::

>>> from mrmustard.lab_dev import Dgate
Expand All @@ -604,9 +628,17 @@ def to_bargmann(self) -> CircuitComponent:
>>> assert d_bargmann.wires == d.wires
>>> assert isinstance(d_bargmann.representation, Bargmann)
"""
if isinstance(self.representation, Bargmann):
return self
else:
apchytr marked this conversation as resolved.
Show resolved Hide resolved

ret = copy.deepcopy(self)
if isinstance(self.representation, Bargmann): # TODO: better name for Bargmann class
# check cc rep
if not indices:
indices = self.wires.indices
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
indices = self.wires.indices
indices = indices if indices else self.wires.indices


ret = ret._apply_btoq_for_change_of_rep(indices)
ret = ret._apply_btops_for_change_of_rep(indices)

elif isinstance(self.representation, Fock):
if self.representation.ansatz._original_abc_data:
A, b, c = self.representation.ansatz._original_abc_data
else:
Expand All @@ -620,7 +652,7 @@ def to_bargmann(self) -> CircuitComponent:
ret = self._from_attributes(bargmann, self.wires, self.name)
if "manual_shape" in ret.__dict__:
del ret.manual_shape
return ret
return ret

def _add_parameter(self, parameter: Constant | Variable):
r"""
Expand Down Expand Up @@ -650,6 +682,72 @@ def _getitem_builtin(self, modes: set[int]):
kwargs = self.parameter_set[items].to_dict()
return self.__class__(modes=modes, **kwargs)

def _apply_btops_for_change_of_rep(self, indices: Sequence[int]) -> CircuitComponent:
r"""
Helper function for change of representation in to_bargmann()

Args:
indices: the set of indices that we want to be represented in bargmann.

Output:
the cc object with Bargmann representation on the specified indices. The representations on the other wires remain intact.
"""

from .circuit_components_utils import BtoPS

ret = copy.deepcopy(self)

for i in indices:
name, arg = self._index_representation[i]

if name == "PS":
ret._index_representation[i] = ("B", None)
m = self.wires.index_to_mode_dict[i]
if i in self.wires.output.bra.indices:
if m not in self.wires.output.ket.modes:
raise ValueError(
f"The object does not have a consistent representation. Mode {m} with PS representation has appeared only on the output bra."
)
friend_index = self.wires.index_dicts[2][m]
ret._index_representation[friend_index] = ("B", None)
ret = ret @ BtoPS([m], s=arg).adjoint.inverse()

if i in self.wires.input.bra.indices:
if m not in self.wires.input.ket.modes:
raise ValueError(
f"The object does not have a consistent representation. Mode {m} with PS representation has appeared only on the input bra."
)
friend_index = self.wires.index_dicts[3][m]
ret._index_representation[friend_index] = ("B", None)
ret = BtoPS([m], s=arg).dual.inverse() @ ret

return ret

def _apply_btoq_for_change_of_rep(self, indices: Sequence[int]) -> CircuitComponent:
r"""
Helper function for change of representation in to_bargmann()
"""

from .circuit_components_utils import BtoQ

ret = copy.deepcopy(self)
for i in indices:
name, arg = self._index_representation[i]
if name == "Q":
ret._index_representation[i] = ("B", None) # perhaps not needed -- can be removed
if i in self.wires.output.bra.indices:
ret = ret @ BtoQ([self.wires.index_to_mode_dict[i]], phi=arg).adjoint.inverse()
if i in self.wires.output.ket.indices:
ret = ret @ BtoQ([self.wires.index_to_mode_dict[i]], phi=arg).inverse()
if i in self.wires.input.bra.indices:
ret = (
BtoQ([self.wires.index_to_mode_dict[i]], phi=arg).dual.adjoint.inverse()
@ ret
)
if i in self.wires.input.ket.indices:
ret = BtoQ([self.wires.index_to_mode_dict[i]], phi=arg).dual.inverse() @ ret
return ret

def _light_copy(self, wires: Wires | None = None) -> CircuitComponent:
r"""
Creates a "light" copy of this component by referencing its __dict__, except for the wires,
Expand Down Expand Up @@ -690,8 +788,11 @@ def __add__(self, other: CircuitComponent) -> CircuitComponent:
"""
if self.wires != other.wires:
raise ValueError("Cannot add components with different wires.")
rep = self.representation + other.representation
rep = (
self.to_bargmann().representation + other.to_bargmann().representation
) # addition occurs in bargmann always
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need it in bargmann?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is to avoid adding things with different representations. Lmk if you have a better approach in mind.

name = self.name if self.name == other.name else ""
# TODO: go back to bargmann on all modes
return self._from_attributes(rep, self.wires, name)

def __eq__(self, other) -> bool:
Expand All @@ -700,7 +801,16 @@ def __eq__(self, other) -> bool:

Compares representations and wires, but not the other attributes (e.g. name and parameter set).
"""
return self.representation == other.representation and self.wires == other.wires
from .circuit_components_utils import BtoQ, BtoPS

if (type(self.representation) == type(other.representation) == Fock) or isinstance(
self, (BtoQ, BtoPS)
):
return self.representation == other.representation and self.wires == other.wires
else:
self_rep = self.to_bargmann().representation
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how I feel about this why do we need this change?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well.. two things are equal if they are equal in the same representation. I mean, in physics, we call them the same thing.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But why do we need to call to_bargmann in the __eq__? I would expect it to just be

def __eq__(self, other):
     if isinstance(other, CircuitComponent):
          return self.representation == other.representation and self.wires == other.wires
     return False 

other_rep = other.to_bargmann().representation
return self_rep == other_rep and self.wires == other.wires

def __matmul__(self, other: CircuitComponent | Scalar) -> CircuitComponent:
r"""
Expand All @@ -719,27 +829,96 @@ def __matmul__(self, other: CircuitComponent | Scalar) -> CircuitComponent:
>>> att = Attenuator([0], 0.5)
>>> assert (coh @ att).wires.input.bra # the input bra is still uncontracted
"""
from .circuit_components_utils import BtoQ, BtoPS

if isinstance(other, (numbers.Number, np.ndarray)):
return self * other

wires_result, perm = self.wires @ other.wires
idx_z, idx_zconj = self._matmul_indices(other)
if type(self.representation) is type(other.representation):

if type(self.representation) is type(other.representation) is Fock:
self_rep = self.representation
other_rep = other.representation
else:
self_rep = self.to_bargmann().representation
other_rep = other.to_bargmann().representation
if (
(not isinstance(self, BtoQ))
and (not isinstance(other, BtoQ))
and (not isinstance(self, BtoPS))
and (not isinstance(other, BtoPS))
):
self_copy = copy.deepcopy(self)
other_copy = copy.deepcopy(other)
index_self, index_other = self_copy.wires.contracted_indices(other_copy.wires)
self_rep = self_copy.to_bargmann(
index_self
).representation # this is where the copy is required (to not send the intial objects back to Bargmann)
other_rep = other_copy.to_bargmann(index_other).representation
else:
self_rep = self.representation
other_rep = other.representation

rep = self_rep[idx_z] @ other_rep[idx_zconj]
rep = rep.reorder(perm) if perm else rep
return CircuitComponent._from_attributes(rep, wires_result, None)
result = CircuitComponent._from_attributes(rep, wires_result, None)

# REMEMBER the representations:
# set the index_representation of uncontracted indices:
# (this will be overwritten if we have a change of representation e.g. other == BtoQ)
for m in other.wires.output.bra.modes:
i = result.wires.index_dicts[0][m]
j = other.wires.index_dicts[0][m]
result._index_representation[i] = other._index_representation[j]
for m in other.wires.output.ket.modes:
i = result.wires.index_dicts[2][m]
j = other.wires.index_dicts[2][m]
result._index_representation[i] = other._index_representation[j]

for m in self.wires.input.bra.modes:
i = result.wires.index_dicts[1][m]
j = self.wires.index_dicts[1][m]
result._index_representation[i] = self._index_representation[j]
for m in self.wires.input.ket.modes:
i = result.wires.index_dicts[3][m]
j = self.wires.index_dicts[3][m]
result._index_representation[i] = self._index_representation[j]

# now we check for indices that might have been contracted:
idx_1, idx_2 = self.wires.contracted_indices(other.wires)

for m in other.wires.input.bra.modes:
j = other.wires.index_dicts[1][m]
if j not in idx_2:
i = result.wires.index_dicts[1][m]
result._index_representation[i] = other._index_representation[j]
for m in other.wires.input.ket.modes:

j = other.wires.index_dicts[3][m]
if j not in idx_2:
i = result.wires.index_dicts[3][m]
result._index_representation[i] = other._index_representation[j]

for m in self.wires.output.bra.modes:
j = self.wires.index_dicts[0][m]
if j not in idx_1:
i = result.wires.index_dicts[0][m]
result._index_representation[i] = self._index_representation[j]

for m in self.wires.output.ket.modes:
j = self.wires.index_dicts[2][m]
if j not in idx_1:
i = result.wires.index_dicts[2][m]
result._index_representation[i] = self._index_representation[j]

return result

def __mul__(self, other: Scalar) -> CircuitComponent:
r"""
Implements the multiplication by a scalar from the right.
"""
return self._from_attributes(self.representation * other, self.wires, self.name)
ret = self._from_attributes(self.representation * other, self.wires, self.name)
ret._index_representation = self._index_representation
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you need to copy?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure I understand the question. Can you please elaborate?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same issue you encountered before where the solution was to deepcopy the dictionary. As in, ret._index_representation is now pointing to the dictionary of self._index_representation so any mutations to the latter propagate to the former

return ret

def __repr__(self) -> str:
repr = self.representation
Expand Down Expand Up @@ -828,6 +1007,7 @@ def __rshift__(self, other: CircuitComponent | numbers.Number) -> CircuitCompone
msg = f"``>>`` not supported between {self} and {other} because it's not clear "
msg += "whether or where to add bra wires. Use ``@`` instead and specify all the components."
raise ValueError(msg)

return self._rshift_return(ret)

def __sub__(self, other: CircuitComponent) -> CircuitComponent:
Expand Down
Loading
Loading