Skip to content

Commit

Permalink
Better handling of Matlab RingParam element (#741)
Browse files Browse the repository at this point in the history
identify the RingParam placehoder and removr it when saving
  • Loading branch information
lfarv committed Feb 25, 2024
1 parent acda62e commit 099e47e
Showing 1 changed file with 93 additions and 67 deletions.
160 changes: 93 additions & 67 deletions pyat/at/load/matfile.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,41 @@
"""
Load lattices from Matlab files.
"""
from __future__ import print_function

from __future__ import print_function, annotations

__all__ = ["load_mat", "save_mat", "load_m", "save_m", "load_var"]

import sys
from os.path import abspath, basename, splitext
from typing import Optional
from collections.abc import Sequence, Generator
from warnings import warn
from typing import Optional, Generator, Sequence
import scipy.io

import numpy
from ..lattice import elements, AtWarning, params_filter, AtError
from ..lattice import Element, Lattice
import scipy.io

from .allfiles import register_format
from .utils import element_from_dict, element_from_m, RingParam
from .utils import element_to_dict, element_to_m
from ..lattice import Element, Lattice
from ..lattice import elements, AtWarning, params_filter, AtError

__all__ = ['load_mat', 'save_mat', 'load_m', 'save_m',
'load_var']

_param_to_lattice = {'Energy': 'energy', 'Periodicity': 'periodicity',
'FamName': 'name', 'Particle': '_particle',
'cell_harmnumber': '_harmcell',
'HarmNumber': 'harmonic_number'}
_param_ignore = {'PassMethod', 'Length', 'cavpts'}
_param_to_lattice = {
"Energy": "energy",
"Periodicity": "periodicity",
"FamName": "name",
"Particle": "_particle",
"cell_harmnumber": "_harmcell",
"HarmNumber": "harmonic_number",
}
_param_ignore = {"PassMethod", "Length", "cavpts"}

# Python to Matlab attribute translation
_matattr_map = dict(((v, k) for k, v in _param_to_lattice.items()))


def matfile_generator(params: dict, mat_file: str)\
-> Generator[Element, None, None]:
def matfile_generator(params: dict, mat_file: str) -> Generator[Element, None, None]:
"""Run through Matlab cells and generate AT elements
Parameters:
Expand All @@ -48,10 +55,11 @@ def matfile_generator(params: dict, mat_file: str)\
Yields:
elem (Element): new Elements
"""

def mclean(data):
if data.dtype.type is numpy.str_:
# Convert strings in arrays back to strings.
return str(data[0]) if data.size > 0 else ''
return str(data[0]) if data.size > 0 else ""
elif data.size == 1:
v = data[0, 0]
if issubclass(v.dtype.type, numpy.void):
Expand All @@ -64,25 +72,28 @@ def mclean(data):
# Remove any surplus dimensions in arrays.
return numpy.squeeze(data)

m = scipy.io.loadmat(params.setdefault('mat_file', mat_file))
matvars = [varname for varname in m if not varname.startswith('__')]
default_key = matvars[0] if (len(matvars) == 1) else 'RING'
key = params.setdefault('mat_key', default_key)
# noinspection PyUnresolvedReferences
m = scipy.io.loadmat(params.setdefault("mat_file", mat_file))
matvars = [varname for varname in m if not varname.startswith("__")]
default_key = matvars[0] if (len(matvars) == 1) else "RING"
key = params.setdefault("mat_key", default_key)
if key not in m.keys():
kok = [k for k in m.keys() if '__' not in k]
raise AtError('Selected mat_key does not exist, '
'please select in: {}'.format(kok))
check = params.pop('check', True)
quiet = params.pop('quiet', False)
kok = [k for k in m.keys() if "__" not in k]
raise AtError(
"Selected mat_key does not exist, please select in: {}".format(kok)
)
check = params.pop("check", True)
quiet = params.pop("quiet", False)
cell_array = m[key].flat
for index, mat_elem in enumerate(cell_array):
elem = mat_elem[0, 0]
kwargs = {f: mclean(elem[f]) for f in elem.dtype.fields}
yield element_from_dict(kwargs, index=index, check=check, quiet=quiet)


def ringparam_filter(params: dict, elem_iterator, *args)\
-> Generator[Element, None, None]:
def ringparam_filter(
params: dict, elem_iterator, *args
) -> Generator[Element, None, None]:
"""Run through all elements, process and optionally removes
RingParam elements
Expand All @@ -109,12 +120,13 @@ def ringparam_filter(params: dict, elem_iterator, *args)\
* ``_particle``
* ``_radiation``
"""
keep_all = params.pop('keep_all', False)
keep_all = params.pop("keep_all", False)
ringparams = []
radiate = False
for elem in elem_iterator(params, *args):
if (elem.PassMethod.endswith('RadPass') or
elem.PassMethod.endswith('CavityPass')):
if elem.PassMethod.endswith("RadPass") or elem.PassMethod.endswith(
"CavityPass"
):
radiate = True
if isinstance(elem, RingParam):
ringparams.append(elem)
Expand All @@ -123,14 +135,14 @@ def ringparam_filter(params: dict, elem_iterator, *args)\
params.setdefault(_param_to_lattice.get(k, k.lower()), v)
if keep_all:
pars = vars(elem).copy()
name = pars.pop('FamName')
yield elements.Marker(name, **pars)
name = pars.pop("FamName")
yield elements.Marker(name, tag="RingParam", **pars)
else:
yield elem
params['_radiation'] = radiate
params["_radiation"] = radiate

if len(ringparams) > 1:
warn(AtWarning('More than 1 RingParam element, the 1st one is used'))
warn(AtWarning("More than 1 RingParam element, the 1st one is used"))


def load_mat(filename: str, **kwargs) -> Lattice:
Expand Down Expand Up @@ -163,15 +175,19 @@ def load_mat(filename: str, **kwargs) -> Lattice:
See Also:
:py:func:`.load_lattice` for a generic lattice-loading function.
"""
if 'key' in kwargs: # process the deprecated 'key' keyword
kwargs.setdefault('mat_key', kwargs.pop('key'))
return Lattice(ringparam_filter, matfile_generator, abspath(filename),
iterator=params_filter, **kwargs)


def mfile_generator(params: dict, m_file: str)\
-> Generator[Element, None, None]:
"""
if "key" in kwargs: # process the deprecated 'key' keyword
kwargs.setdefault("mat_key", kwargs.pop("key"))
return Lattice(
ringparam_filter,
matfile_generator,
abspath(filename),
iterator=params_filter,
**kwargs,
)


def mfile_generator(params: dict, m_file: str) -> Generator[Element, None, None]:
"""Run through the lines of a Matlab m-file and generate AT elements
Parameters:
Expand All @@ -180,20 +196,20 @@ def mfile_generator(params: dict, m_file: str)\
Yields:
elem (Element): new Elements
"""
with open(params.setdefault('m_file', m_file), 'rt') as file:
"""
with open(params.setdefault("m_file", m_file), "rt") as file:
_ = next(file) # Matlab function definition
_ = next(file) # Cell array opening
for lineno, line in enumerate(file):
if line.startswith('};'):
if line.startswith("};"):
break
try:
elem = element_from_m(line)
except ValueError:
warn(AtWarning('Invalid line {0} skipped.'.format(lineno)))
warn(AtWarning("Invalid line {0} skipped.".format(lineno)))
continue
except KeyError:
warn(AtWarning('Line {0}: Unknown class.'.format(lineno)))
warn(AtWarning("Line {0}: Unknown class.".format(lineno)))
continue
else:
yield elem
Expand Down Expand Up @@ -223,8 +239,13 @@ def load_m(filename: str, **kwargs) -> Lattice:
See Also:
:py:func:`.load_lattice` for a generic lattice-loading function.
"""
return Lattice(ringparam_filter, mfile_generator, abspath(filename),
iterator=params_filter, **kwargs)
return Lattice(
ringparam_filter,
mfile_generator,
abspath(filename),
iterator=params_filter,
**kwargs,
)


def load_var(matlat: Sequence[dict], **kwargs) -> Lattice:
Expand All @@ -248,28 +269,32 @@ def load_var(matlat: Sequence[dict], **kwargs) -> Lattice:
Returns:
lattice (Lattice): New :py:class:`.Lattice` object
"""

# noinspection PyUnusedLocal
def var_generator(params, latt):
for elem in latt:
yield element_from_dict(elem)

return Lattice(ringparam_filter, var_generator, matlat,
iterator=params_filter, **kwargs)
return Lattice(
ringparam_filter, var_generator, matlat, iterator=params_filter, **kwargs
)


def matlab_ring(ring) -> Generator[Element, None, None]:
def matlab_ring(ring: Lattice) -> Generator[Element, None, None]:
"""Prepend a RingParam element to a lattice"""
dct = dict((_matattr_map.get(k, k.title()), v)
for k, v in ring.attrs.items())
famname = dct.pop('FamName')
energy = dct.pop('Energy')
dct = dict((_matattr_map.get(k, k.title()), v) for k, v in ring.attrs.items())
famname = dct.pop("FamName")
energy = dct.pop("Energy")
yield RingParam(famname, energy, **dct)
for elem in ring:
yield elem
if not (
isinstance(elem, elements.Marker)
and getattr(elem, "tag", None) == "RingParam"
):
yield elem


def save_mat(ring: Lattice, filename: str,
mat_key: str = 'RING') -> None:
def save_mat(ring: Lattice, filename: str, mat_key: str = "RING") -> None:
"""Save a :py:class:`.Lattice` as a Matlab mat-file
Parameters:
Expand All @@ -282,6 +307,7 @@ def save_mat(ring: Lattice, filename: str,
:py:func:`.save_lattice` for a generic lattice-saving function.
"""
lring = tuple((element_to_dict(elem),) for elem in matlab_ring(ring))
# noinspection PyUnresolvedReferences
scipy.io.savemat(filename, {mat_key: lring}, long_field_names=True)


Expand All @@ -298,20 +324,20 @@ def save_m(ring: Lattice, filename: Optional[str] = None) -> None:
"""

def save(file):
print('ring = {...', file=file)
print("ring = {...", file=file)
for elem in matlab_ring(ring):
print(element_to_m(elem), file=file)
print('};', file=file)
print("};", file=file)

if filename is None:
save(sys.stdout)
else:
with open(filename, 'wt') as mfile:
with open(filename, "wt") as mfile:
[funcname, _] = splitext(basename(filename))
print('function ring = {0}()'.format(funcname), file=mfile)
print("function ring = {0}()".format(funcname), file=mfile)
save(mfile)
print('end', file=mfile)
print("end", file=mfile)


register_format('.mat', load_mat, save_mat, descr='Matlab binary mat-file')
register_format('.m', load_m, save_m, descr='Matlab text m-file')
register_format(".mat", load_mat, save_mat, descr="Matlab binary mat-file")
register_format(".m", load_m, save_m, descr="Matlab text m-file")

0 comments on commit 099e47e

Please sign in to comment.