Skip to content

Commit

Permalink
Fix saving lattices in m-files (#832)
Browse files Browse the repository at this point in the history
* Fix m-file writing

* blackened matfile.py
  • Loading branch information
lfarv authored Sep 9, 2024
1 parent ef2daa0 commit ac67bfa
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 28 deletions.
2 changes: 1 addition & 1 deletion pyat/at/load/madx.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ def beta() -> float:
if cavities:
cavities.sort(key=lambda el: el.Frequency)
c0 = cavities[0]
params["_harmnumber"] = getattr(c0, "HarmNumber", np.nan)
params["_cell_harmnumber"] = getattr(c0, "HarmNumber", np.nan)

part = kwargs.get("particle", None)
if isinstance(part, str):
Expand Down
69 changes: 42 additions & 27 deletions pyat/at/load/matfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
Load lattices from Matlab files.
"""

from __future__ import print_function, annotations
from __future__ import annotations

__all__ = ["load_mat", "save_mat", "load_m", "save_m", "load_var"]

import sys
import os
from os.path import abspath, basename, splitext
from typing import Optional, Any
from typing import Any
from collections.abc import Sequence, Generator
from warnings import warn

Expand Down Expand Up @@ -41,15 +41,15 @@
"Beam_Current": None,
"Nbunch": None,
}
_p2m = dict((v, k) for k, v in _m2p.items() if v is not None)
_p2m = {v: k for k, v in _m2p.items() if v is not None}
# Attribute to drop when writing a file
_p2m.update(_drop_attrs)

# Python to Matlab type translation
_mattype_map = {
int: float,
tuple: list,
np.ndarray: lambda attr: np.asanyarray(attr),
list: lambda attr: np.array(attr, dtype=object),
tuple: lambda attr: np.array(attr, dtype=object),
Particle: lambda attr: attr.to_dict(),
}
# Matlab constructor function
Expand Down Expand Up @@ -157,7 +157,10 @@ def ringparam_filter(
params["_radiation"] = radiate

if len(ringparams) > 1:
warn(AtWarning("More than 1 RingParam element, the 1st one is used"))
warn(
AtWarning("More than 1 RingParam element, the 1st one is used"),
stacklevel=2,
)


def load_mat(filename: str, **kwargs) -> Lattice:
Expand Down Expand Up @@ -297,7 +300,7 @@ def load_m(filename: str, **kwargs) -> Lattice:

def mfile_generator(params: dict, m_file: str) -> Generator[Element, None, None]:
"""Run through the lines of a Matlab m-file and generate AT elements"""
with open(params.setdefault("in_file", m_file), "rt") as file:
with open(params.setdefault("in_file", m_file)) as file:
_ = next(file) # Matlab function definition
_ = next(file) # Cell array opening
for lineno, line in enumerate(file):
Expand All @@ -306,10 +309,10 @@ def mfile_generator(params: dict, m_file: str) -> Generator[Element, None, None]
try:
elem = _element_from_m(line)
except ValueError:
warn(AtWarning("Invalid line {0} skipped.".format(lineno)))
warn(AtWarning(f"Invalid line {lineno} skipped."), stacklevel=2)
continue
except KeyError:
warn(AtWarning("Line {0}: Unknown class.".format(lineno)))
warn(AtWarning(f"Line {lineno}: Unknown class."), stacklevel=2)
continue
else:
yield elem
Expand Down Expand Up @@ -360,7 +363,7 @@ def matlab_ring(ring: Lattice) -> Generator[Element, None, None]:

def required(rng):
# Public lattice attributes
params = dict((k, v) for k, v in vars(rng).items() if not k.startswith("_"))
params = {k: v for k, v in vars(rng).items() if not k.startswith("_")}
# Output the required attributes/properties
for kp, km in _p2m.items():
try:
Expand Down Expand Up @@ -418,21 +421,29 @@ def scan(d):
yield convert(k)
yield convert(v)

return "struct({0})".format(", ".join(scan(pdir)))
return "struct({})".format(", ".join(scan(pdir)))

def convert_array(arr):
if arr.ndim > 1:
lns = (str(list(ln)).replace(",", "")[1:-1] for ln in arr)
return "".join(("[", "; ".join(lns), "]"))
return np.array2string(arg).replace("\n", ";")
elif arr.ndim > 0:
return str(list(arr)).replace(",", "")
return np.array2string(arg)
else:
return str(arr)

def convert_list(lst):
return f"{{{{{str(lst)[1:-1]}}}}}"

if isinstance(arg, np.ndarray):
return convert_array(arg)
elif isinstance(arg, np.number):
return str(arg)
elif isinstance(arg, dict):
return convert_dict(arg)
elif isinstance(arg, tuple):
return convert_list(arg)
elif isinstance(arg, list):
return convert_list(arg)
elif isinstance(arg, Particle):
return convert_dict(arg.to_dict())
else:
Expand All @@ -446,19 +457,19 @@ def m_name(elclass):
# noinspection PyProtectedMember
args = [attrs.pop(k, getattr(elem, k)) for k in elem._BUILD_ATTRIBUTES]
defelem = elem.__class__(*args)
kwds = dict(
(k, v)
kwds = {
k: v
for k, v in attrs.items()
if not np.array_equal(v, getattr(defelem, k, None))
)
}
argstrs = [convert(arg) for arg in args]
if "PassMethod" in kwds:
argstrs.append(convert(kwds.pop("PassMethod")))
argstrs += [", ".join((repr(k), convert(v))) for k, v in kwds.items()]
return "{0:>15}({1});...".format(m_name(elem.__class__), ", ".join(argstrs))
return "{:>15}({});...".format(m_name(elem.__class__), ", ".join(argstrs))


def save_m(ring: Lattice, filename: Optional[str] = None) -> None:
def save_m(ring: Lattice, filename: str | None = None) -> None:
"""Save a :py:class:`.Lattice` as a Matlab m-file
Parameters:
Expand All @@ -471,18 +482,21 @@ def save_m(ring: Lattice, filename: Optional[str] = None) -> None:
"""

def save(file):
print("ring = {...", file=file)
for elem in matlab_ring(ring):
print(_element_to_m(elem), file=file)
print("};", file=file)
with np.printoptions(linewidth=1000, floatmode="unique"):
print("ring = {...", file=file)
for elem in matlab_ring(ring):
print(_element_to_m(elem), file=file)
print("};", file=file)

if filename is None:
save(sys.stdout)
else:
with open(filename, "wt") as mfile:
with open(filename, "w") as mfile:
[funcname, _] = splitext(basename(filename))
print("function ring = {0}()".format(funcname), file=mfile)
print(f"function ring = {funcname}()", file=mfile)
save(mfile)
print(" function v=False()\n v=false;\n end", file=mfile)
print(" function v=True()\n v=true;\n end", file=mfile)
print("end", file=mfile)


Expand All @@ -492,7 +506,7 @@ def _mat_file(ring):
try:
in_file = ring.in_file
except AttributeError:
raise AttributeError("'Lattice' object has no attribute 'mat_file'")
raise AttributeError("'Lattice' object has no attribute 'mat_file'") from None
if isinstance(in_file, str):
_, ext = os.path.splitext(in_file)
if ext != ".mat":
Expand All @@ -507,10 +521,11 @@ def _mat_key(ring):
try:
mat_key = ring.use
except AttributeError:
raise AttributeError("'Lattice' object has no attribute 'mat_key'")
raise AttributeError("'Lattice' object has no attribute 'mat_key'") from None
return mat_key


# noinspection PyUnusedLocal
def _ignore(ring, value):
pass

Expand Down

0 comments on commit ac67bfa

Please sign in to comment.