Skip to content

Commit

Permalink
Preserve the case of RingParam attribute names (#744)
Browse files Browse the repository at this point in the history
* Don't change the case of Lattice unexpected attributes

* Reorganise the processing of lattice attributes

* Cleaning of warnings and PEP8 check
  • Loading branch information
lfarv authored Mar 11, 2024
1 parent 099e47e commit f77cd65
Show file tree
Hide file tree
Showing 4 changed files with 301 additions and 230 deletions.
146 changes: 70 additions & 76 deletions pyat/at/lattice/lattice_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
As an example, see the at.physics.orbit module
"""
from __future__ import annotations

__all__ = ['Lattice', 'Filter', 'type_filter', 'params_filter', 'lattice_filter',
'elem_generator', 'no_filter']

import sys
import copy
import numpy
import math
from typing import Optional, Union
if sys.version_info.minor < 9:
Expand All @@ -23,17 +26,18 @@
from typing import SupportsIndex
from collections.abc import Callable, Iterable, Generator
from warnings import warn
from ..constants import clight, e_mass

import numpy

from . import elements as elt
from .elements import Element
from .particle_object import Particle
from .utils import AtError, AtWarning, Refpts
from .utils import get_s_pos, get_elements,get_value_refpts, set_value_refpts
# noinspection PyProtectedMember
from .utils import get_uint32_index, get_bool_index, _refcount, Uint32Refpts
from .utils import refpts_iterator, checktype
from .utils import get_s_pos, get_elements
from .utils import get_value_refpts, set_value_refpts
from .utils import set_shift, set_tilt, get_geometry
from . import elements as elt
from .elements import Element
from .utils import refpts_iterator, checktype, set_shift, set_tilt, get_geometry
from ..constants import clight, e_mass

_TWO_PI_ERROR = 1.E-4
Filter = Callable[..., Iterable[Element]]
Expand Down Expand Up @@ -69,10 +73,7 @@
)
}

__all__ = ['Lattice', 'type_filter', 'params_filter', 'lattice_filter',
'elem_generator', 'no_filter']

# Don't warn on floating-pont errors
# Don't warn on floating-point errors
numpy.seterr(divide='ignore', invalid='ignore')


Expand All @@ -96,9 +97,7 @@ class Lattice(list):
'_fillpattern')

# noinspection PyUnusedLocal
def __init__(self, *args,
iterator: Filter = None,
scan: bool = False, **kwargs):
def __init__(self, *args, iterator: Filter = None, scan: bool = False, **kwargs):
"""
Lattice(elements, **params)
Lattice(filter, [filter, ...,] iterator=iter,**params)
Expand Down Expand Up @@ -181,7 +180,7 @@ def __init__(self, *args,
energy and periodicity if not yet defined.
"""
if iterator is None:
arg1, = args or [[]] # accept 0 or 1 argument
(arg1,) = args or [[]] # accept 0 or 1 argument
if isinstance(arg1, Lattice):
elems = lattice_filter(kwargs, arg1)
else:
Expand All @@ -195,25 +194,21 @@ def __init__(self, *args,
for attr in self._excluded_attributes:
kwargs.pop(attr, None)
# set default values
kwargs.setdefault('name', '')
periodicity = kwargs.setdefault('periodicity', 1)
kwargs.setdefault('_particle', Particle())
kwargs.setdefault("name", "")
periodicity = kwargs.setdefault("periodicity", 1)
kwargs.setdefault("particle", kwargs.pop("_particle", Particle()))
kwargs.setdefault("beam_current", kwargs.pop("_beam_current", 0.0))
# dummy initialization in case the harmonic number is not there
kwargs.setdefault('_fillpattern', numpy.ones(1))
kwargs.setdefault("_fillpattern", numpy.ones(1))
# Remove temporary keywords
frequency = kwargs.pop('_frequency', None)
cell_length = kwargs.pop('_length', None)
cell_h = kwargs.pop('_harmnumber', math.nan)
ring_h = kwargs.pop('harmonic_number', periodicity*cell_h)
bcurrent = kwargs.pop('beam_current', 0.0)
kwargs.setdefault('_beam_current', bcurrent)

if 'energy' in kwargs:
kwargs.pop('_energy', None)
elif '_energy' not in kwargs:
raise AtError('Lattice energy is not defined')
if 'particle' in kwargs:
kwargs.pop('_particle', None)
frequency: Optional[float] = kwargs.pop("_frequency", None)
cell_length: Optional[float] = kwargs.pop("_length", None)
cell_h = kwargs.pop("cell_harmnumber", kwargs.pop("_cell_harmnumber", math.nan))
ring_h = kwargs.pop("harmonic_number", cell_h * periodicity)

energy = kwargs.setdefault("energy", kwargs.pop("_energy", None))
if energy is None:
raise AtError("Lattice energy is not defined")

# set attributes
self.update(kwargs)
Expand All @@ -223,12 +218,12 @@ def __init__(self, *args,
rev = self.beta * clight / cell_length
self._cell_harmnumber = int(round(frequency / rev))
try:
fp = kwargs.pop('_fillpattern', numpy.ones(1))
fp = kwargs.pop("_fillpattern", numpy.ones(1))
self.set_fillpattern(bunches=fp)
except AssertionError:
self.set_fillpattern()
elif not math.isnan(ring_h):
self.harmonic_number = ring_h
self._cell_harmnumber = ring_h / periodicity

def __getitem__(self, key):
try: # Integer
Expand Down Expand Up @@ -329,7 +324,8 @@ def extend(self, elems: Iterable[Element], copy_elements=False):
r"""This method adds all the elements of `elems` to the end of the
lattice. The behavior is the same as for a :py:obj:`list`
Equivalents syntaxes:
Equivalent syntaxes:
>>> ring.extend(elems)
>>> ring += elems
Expand All @@ -347,38 +343,39 @@ def extend(self, elems: Iterable[Element], copy_elements=False):
super().extend(elems)

def append(self, elem: Element, copy_elements=False):
r"""This method overwrites the inherited method
:py:meth:`list.append()`,
it behavior is changed, it accepts only AT lattice elements
:py:obj:`Element` as input argument.
Equivalents syntaxes:
>>> ring.append(elem)
>>> ring += [elem]
Parameters:
elem (Element): AT element to be appended to the lattice
copy_elements(bool): Default :py:obj:`True`.
If :py:obj:`True` a deep copy of elem
is used
r"""This method overwrites the inherited method :py:meth:`list.append()`,
its behavior is changed, it accepts only AT lattice elements
:py:obj:`Element` as input argument.
Equivalent syntaxes:
>>> ring.append(elem)
>>> ring += [elem]
Parameters:
elem (Element): AT element to be appended to the lattice
copy_elements(bool): Default :py:obj:`True`.
If :py:obj:`True` a deep copy of elem
is used
"""
self.extend([elem], copy_elements=copy_elements)

def repeat(self, n: int, copy_elements=True):
def repeat(self, n: int, copy_elements: bool = True):
r"""This method allows to repeat the lattice `n` times.
If `n` does not divide `ring.periodicity`, the new ring
periodicity is set to 1, otherwise it is et to
periodicity is set to 1, otherwise it is set to
`ring.periodicity /= n`.
Equivalents syntaxes:
Equivalent syntaxes:
>>> newring = ring.repeat(n)
>>> newring = ring * n
Parameters:
n (int): number of repetition
copy_elements(bool): Default :py:obj:`True`.
If :py:obj:`True` deepcopies of the
lattice are used for the repetition
n : number of repetitions
copy_elements: If :py:obj:`True`, deep copies of the lattice are used for
the repetition. Otherwise, the original elements are repeated in the
developed lattice.
Returns:
newring (Lattice): the new repeated lattice
Expand All @@ -397,21 +394,21 @@ def copy_fun(elem, copy):
warn(AtWarning('Non-integer number of cells: {}/{}. Periodi'
'city set to 1'.format(self.periodicity, n)))
periodicity = 1
hdict = dict(periodicity=periodicity)
try:
cell_h = self._cell_harmnumber
hdict.update(harmonic_number=self.cell_harmnumber*n*periodicity)
except AttributeError:
hdict = {}
else:
hdict = dict(_cell_harmnumber=n*cell_h)
pass
elems = (copy_fun(el, copy_elements) for _ in range(n) for el in self)
return Lattice(elem_generator, elems, iterator=self.attrs_filter,
periodicity=periodicity, **hdict)
**hdict)

def concatenate(self, *lattices: Iterable[Element],
copy_elements=False, copy=False):
"""Concatenate several `Iterable[Element]` with the lattice
Equivalents syntaxes:
Equivalent syntaxes:
>>> newring = ring.concatenate(r1, r2, r3, copy=True)
>>> newring = ring + r1 + r2 + r3
Expand Down Expand Up @@ -464,20 +461,19 @@ def reverse(self, copy=False):
reversed_list = list(elems)
self[:] = reversed_list

def develop(self) -> Lattice:
def develop(self, copy_elements: bool = True) -> Lattice:
"""Develop a periodical lattice by repeating its elements
*self.periodicity* times
The elements of the new lattice are deep copies ot the original
elements, so that they are all independent.
Parameters:
copy_elements: If :py:obj:`True`, deep copies of the elements are used for
the repetition. Otherwise, the original elements are repeated in the
developed lattice.
Returns:
newlattice: The developed lattice
"""
elist = (el.deepcopy() for _ in range(self.periodicity) for el in self)
return Lattice(elem_generator, elist,
iterator=self.attrs_filter, periodicity=1,
harmonic_number=self.harmonic_number)
return self.repeat(self.periodicity, copy_elements=copy_elements)

@property
def attrs(self) -> dict:
Expand Down Expand Up @@ -1415,11 +1411,10 @@ def elem_generator(params, elems: Iterable[Element]) -> Iterable[Element]:
return elems


no_filter = elem_generator # provided for backward compatibility
no_filter: Filter = elem_generator # provided for backward compatibility


def type_filter(params, elems: Iterable[Element]) \
-> Generator[Element, None, None]:
def type_filter(params, elems: Iterable[Element]) -> Generator[Element, None, None]:
"""Run through all elements and check element validity.
Analyse elements for radiation state
Expand All @@ -1429,7 +1424,7 @@ def type_filter(params, elems: Iterable[Element]) \
Yields:
lattice ``Elements``
"""
"""
radiate = False
for idx, elem in enumerate(elems):
if isinstance(elem, Element):
Expand All @@ -1442,8 +1437,7 @@ def type_filter(params, elems: Iterable[Element]) \
params['_radiation'] = radiate


def params_filter(params, elem_filter: Filter, *args) \
-> Generator[Element, None, None]:
def params_filter(params, elem_filter: Filter, *args) -> Generator[Element, None, None]:
"""Run through all elements, looking for energy and periodicity.
Remove the Energy attribute of non-radiating elements
Expand Down
47 changes: 31 additions & 16 deletions pyat/at/load/matfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import sys
from os.path import abspath, basename, splitext
from typing import Optional
from typing import Optional, Any
from collections.abc import Sequence, Generator
from warnings import warn

Expand All @@ -18,24 +18,25 @@
from .allfiles import register_format
from .utils import element_from_dict, element_from_m, RingParam
from .utils import element_to_dict, element_to_m
from ..lattice import Element, Lattice
from ..lattice import Element, Lattice, Filter
from ..lattice import elements, AtWarning, params_filter, AtError

_param_to_lattice = {
_m2p = {
"FamName": "name",
"Energy": "energy",
"Periodicity": "periodicity",
"FamName": "name",
"Particle": "_particle",
"cell_harmnumber": "_harmcell",
"HarmNumber": "harmonic_number",
"Particle": "particle",
"cell_harmnumber": "cell_harmnumber",
}
_param_ignore = {"PassMethod", "Length", "cavpts"}

# Python to Matlab attribute translation
_matattr_map = dict(((v, k) for k, v in _param_to_lattice.items()))
# Python to Matlab
_p2m = {"name", "energy", "periodicity", "particle", "cell_harmnumber", "beam_current"}


def matfile_generator(params: dict, mat_file: str) -> Generator[Element, None, None]:
def matfile_generator(
params: dict[str, Any], mat_file: str
) -> Generator[Element, None, None]:
"""Run through Matlab cells and generate AT elements
Parameters:
Expand Down Expand Up @@ -92,7 +93,7 @@ def mclean(data):


def ringparam_filter(
params: dict, elem_iterator, *args
params: dict[str, Any], elem_iterator: Filter, *args
) -> Generator[Element, None, None]:
"""Run through all elements, process and optionally removes
RingParam elements
Expand Down Expand Up @@ -132,7 +133,7 @@ def ringparam_filter(
ringparams.append(elem)
for k, v in elem.items():
if k not in _param_ignore:
params.setdefault(_param_to_lattice.get(k, k.lower()), v)
params.setdefault(_m2p.get(k, k), v)
if keep_all:
pars = vars(elem).copy()
name = pars.pop("FamName")
Expand Down Expand Up @@ -282,10 +283,24 @@ def var_generator(params, latt):

def matlab_ring(ring: Lattice) -> Generator[Element, None, None]:
"""Prepend a RingParam element to a lattice"""
dct = dict((_matattr_map.get(k, k.title()), v) for k, v in ring.attrs.items())
famname = dct.pop("FamName")
energy = dct.pop("Energy")
yield RingParam(famname, energy, **dct)

def required(rng):
# Public lattice attributes
params = dict((k, v) for k, v in vars(rng).items() if not k.startswith("_"))
# Output the required attributes/properties
for k in _p2m:
try:
v = getattr(rng, k)
except AttributeError:
pass
else:
params.pop(k, None)
yield k, v
# Output the remaining attributes
yield from params.items()

dct = dict(required(ring))
yield RingParam(**dct)
for elem in ring:
if not (
isinstance(elem, elements.Marker)
Expand Down
3 changes: 1 addition & 2 deletions pyat/at/load/reprfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from at.load import register_format
from at.load.utils import element_from_string
# imports necessary in' globals()' for 'eval'
# noinspection PyUnresolvedReferences
from at.lattice import Particle
from at.lattice import Particle # noqa: F401

__all__ = ['load_repr', 'save_repr']

Expand Down
Loading

0 comments on commit f77cd65

Please sign in to comment.