From acc7afa5e4cb8ad9250f8765607b6e40c340bdfc Mon Sep 17 00:00:00 2001 From: Remco de Boer <29308176+redeboer@users.noreply.github.com> Date: Sat, 3 Sep 2022 23:44:40 +0200 Subject: [PATCH] MAINT: support SymPy v1.11 (#325) * DX: show MyST-NB execution traceback * MAINT: improve type hint consistency with SymPy v1.11 * MAINT: update pip constraints and pre-commit config Co-authored-by: GitHub Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .constraints/py3.10.txt | 18 +++++------ .constraints/py3.7.txt | 16 +++++----- .constraints/py3.8.txt | 18 +++++------ .constraints/py3.9.txt | 18 +++++------ .cspell.json | 1 + .github/workflows/ci-docs.yml | 21 ------------- .pre-commit-config.yaml | 2 +- docs/conf.py | 1 + setup.cfg | 2 +- src/ampform/dynamics/__init__.py | 4 +-- src/ampform/dynamics/phasespace.py | 40 ++++++++++++++----------- src/ampform/helicity/__init__.py | 4 +-- src/ampform/kinematics/__init__.py | 12 ++++---- src/ampform/sympy/__init__.py | 8 ++--- src/ampform/sympy/_array_expressions.py | 6 ++-- 15 files changed, 78 insertions(+), 93 deletions(-) diff --git a/.constraints/py3.10.txt b/.constraints/py3.10.txt index c94763cf5..002338a3b 100644 --- a/.constraints/py3.10.txt +++ b/.constraints/py3.10.txt @@ -63,7 +63,7 @@ ipykernel==6.15.2 ipympl==0.9.2 ipython==8.4.0 ipython-genutils==0.2.0 -ipywidgets==8.0.1 +ipywidgets==8.0.2 isort==5.10.1 jedi==0.18.1 jinja2==3.1.2 @@ -79,7 +79,7 @@ jupyterlab-markup==1.1.0 jupyterlab-myst==0.1.6 ; python_version >= "3.7.0" jupyterlab-pygments==0.2.2 jupyterlab-server==2.15.1 -jupyterlab-widgets==3.0.2 +jupyterlab-widgets==3.0.3 kiwisolver==1.4.4 latexcodec==2.0.1 lazy-object-proxy==1.7.1 @@ -114,7 +114,7 @@ packaging==21.3 pandocfilters==1.5.0 parso==0.8.3 particle==0.20.1 -pathspec==0.10.0 +pathspec==0.10.1 pep8-naming==0.13.2 pexpect==4.8.0 pickleshare==0.7.5 @@ -123,7 +123,7 @@ platformdirs==2.5.2 pluggy==1.0.0 pre-commit==2.20.0 prometheus-client==0.14.1 -prompt-toolkit==3.0.30 +prompt-toolkit==3.0.31 psutil==5.9.1 ptyprocess==0.7.0 pure-eval==0.2.2 @@ -140,7 +140,7 @@ pygments==2.13.0 pylint==2.15.0 pyparsing==3.0.9 pyrsistent==0.18.1 -pytest==7.1.2 +pytest==7.1.3 pytest-cov==3.0.0 pytest-forked==1.4.0 pytest-profiling==1.7.0 @@ -157,7 +157,7 @@ restructuredtext-lint==1.4.0 rich==12.5.1 send2trash==1.8.0 six==1.16.0 -sniffio==1.2.0 +sniffio==1.3.0 snowballstemmer==2.2.0 soupsieve==2.3.2.post1 sphinx==4.5.0 @@ -179,7 +179,7 @@ sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 sqlalchemy==1.4.40 stack-data==0.5.0 -sympy==1.10.1 +sympy==1.11.1 tabulate==0.8.10 terminado==0.15.0 tinycss2==1.1.1 @@ -188,7 +188,7 @@ tomli==2.0.1 tomlkit==0.11.4 tornado==6.2 tox==3.25.1 -tqdm==4.64.0 +tqdm==4.64.1 traitlets==5.3.0 types-docutils==0.19.0 types-pkg-resources==0.1.3 @@ -202,7 +202,7 @@ wcwidth==0.2.5 webencodings==0.5.1 websocket-client==1.4.0 wheel==0.37.1 -widgetsnbextension==4.0.2 +widgetsnbextension==4.0.3 wrapt==1.14.1 zipp==3.8.1 diff --git a/.constraints/py3.7.txt b/.constraints/py3.7.txt index 27c59bc94..20f2dba7a 100644 --- a/.constraints/py3.7.txt +++ b/.constraints/py3.7.txt @@ -61,7 +61,7 @@ ipykernel==6.15.2 ipympl==0.9.2 ipython==7.34.0 ipython-genutils==0.2.0 -ipywidgets==8.0.1 +ipywidgets==8.0.2 isort==5.10.1 jedi==0.18.1 jinja2==3.1.2 @@ -77,7 +77,7 @@ jupyterlab-markup==1.1.0 jupyterlab-myst==0.1.6 ; python_version >= "3.7.0" jupyterlab-pygments==0.2.2 jupyterlab-server==2.15.1 -jupyterlab-widgets==3.0.2 +jupyterlab-widgets==3.0.3 kiwisolver==1.4.4 latexcodec==2.0.1 lazy-object-proxy==1.7.1 @@ -112,7 +112,7 @@ packaging==21.3 pandocfilters==1.5.0 parso==0.8.3 particle==0.20.1 -pathspec==0.10.0 +pathspec==0.10.1 pep8-naming==0.13.2 pexpect==4.8.0 pickleshare==0.7.5 @@ -122,7 +122,7 @@ platformdirs==2.5.2 pluggy==1.0.0 pre-commit==2.20.0 prometheus-client==0.14.1 -prompt-toolkit==3.0.30 +prompt-toolkit==3.0.31 psutil==5.9.1 ptyprocess==0.7.0 py==1.11.0 @@ -138,7 +138,7 @@ pygments==2.13.0 pylint==2.15.0 pyparsing==3.0.9 pyrsistent==0.18.1 -pytest==7.1.2 +pytest==7.1.3 pytest-cov==3.0.0 pytest-forked==1.4.0 pytest-profiling==1.7.0 @@ -156,7 +156,7 @@ rich==12.5.1 send2trash==1.8.0 singledispatchmethod==1.0 ; python_version < "3.8.0" six==1.16.0 -sniffio==1.2.0 +sniffio==1.3.0 snowballstemmer==2.2.0 soupsieve==2.3.2.post1 sphinx==4.3.2 ; python_version < "3.8.0" @@ -186,7 +186,7 @@ tomli==2.0.1 tomlkit==0.11.4 tornado==6.2 tox==3.25.1 -tqdm==4.64.0 +tqdm==4.64.1 traitlets==5.3.0 typed-ast==1.5.4 types-docutils==0.19.0 @@ -201,7 +201,7 @@ wcwidth==0.2.5 webencodings==0.5.1 websocket-client==1.4.0 wheel==0.37.1 -widgetsnbextension==4.0.2 +widgetsnbextension==4.0.3 wrapt==1.14.1 zipp==3.8.1 diff --git a/.constraints/py3.8.txt b/.constraints/py3.8.txt index f16f1b71f..be661928d 100644 --- a/.constraints/py3.8.txt +++ b/.constraints/py3.8.txt @@ -64,7 +64,7 @@ ipykernel==6.15.2 ipympl==0.9.2 ipython==8.4.0 ipython-genutils==0.2.0 -ipywidgets==8.0.1 +ipywidgets==8.0.2 isort==5.10.1 jedi==0.18.1 jinja2==3.1.2 @@ -80,7 +80,7 @@ jupyterlab-markup==1.1.0 jupyterlab-myst==0.1.6 ; python_version >= "3.7.0" jupyterlab-pygments==0.2.2 jupyterlab-server==2.15.1 -jupyterlab-widgets==3.0.2 +jupyterlab-widgets==3.0.3 kiwisolver==1.4.4 latexcodec==2.0.1 lazy-object-proxy==1.7.1 @@ -115,7 +115,7 @@ packaging==21.3 pandocfilters==1.5.0 parso==0.8.3 particle==0.20.1 -pathspec==0.10.0 +pathspec==0.10.1 pep8-naming==0.13.2 pexpect==4.8.0 pickleshare==0.7.5 @@ -125,7 +125,7 @@ platformdirs==2.5.2 pluggy==1.0.0 pre-commit==2.20.0 prometheus-client==0.14.1 -prompt-toolkit==3.0.30 +prompt-toolkit==3.0.31 psutil==5.9.1 ptyprocess==0.7.0 pure-eval==0.2.2 @@ -142,7 +142,7 @@ pygments==2.13.0 pylint==2.15.0 pyparsing==3.0.9 pyrsistent==0.18.1 -pytest==7.1.2 +pytest==7.1.3 pytest-cov==3.0.0 pytest-forked==1.4.0 pytest-profiling==1.7.0 @@ -159,7 +159,7 @@ restructuredtext-lint==1.4.0 rich==12.5.1 send2trash==1.8.0 six==1.16.0 -sniffio==1.2.0 +sniffio==1.3.0 snowballstemmer==2.2.0 soupsieve==2.3.2.post1 sphinx==4.5.0 @@ -181,7 +181,7 @@ sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 sqlalchemy==1.4.40 stack-data==0.5.0 -sympy==1.10.1 +sympy==1.11.1 tabulate==0.8.10 terminado==0.15.0 tinycss2==1.1.1 @@ -190,7 +190,7 @@ tomli==2.0.1 tomlkit==0.11.4 tornado==6.2 tox==3.25.1 -tqdm==4.64.0 +tqdm==4.64.1 traitlets==5.3.0 types-docutils==0.19.0 types-pkg-resources==0.1.3 @@ -204,7 +204,7 @@ wcwidth==0.2.5 webencodings==0.5.1 websocket-client==1.4.0 wheel==0.37.1 -widgetsnbextension==4.0.2 +widgetsnbextension==4.0.3 wrapt==1.14.1 zipp==3.8.1 diff --git a/.constraints/py3.9.txt b/.constraints/py3.9.txt index 56064eaba..e78b4bc29 100644 --- a/.constraints/py3.9.txt +++ b/.constraints/py3.9.txt @@ -63,7 +63,7 @@ ipykernel==6.15.2 ipympl==0.9.2 ipython==8.4.0 ipython-genutils==0.2.0 -ipywidgets==8.0.1 +ipywidgets==8.0.2 isort==5.10.1 jedi==0.18.1 jinja2==3.1.2 @@ -79,7 +79,7 @@ jupyterlab-markup==1.1.0 jupyterlab-myst==0.1.6 ; python_version >= "3.7.0" jupyterlab-pygments==0.2.2 jupyterlab-server==2.15.1 -jupyterlab-widgets==3.0.2 +jupyterlab-widgets==3.0.3 kiwisolver==1.4.4 latexcodec==2.0.1 lazy-object-proxy==1.7.1 @@ -114,7 +114,7 @@ packaging==21.3 pandocfilters==1.5.0 parso==0.8.3 particle==0.20.1 -pathspec==0.10.0 +pathspec==0.10.1 pep8-naming==0.13.2 pexpect==4.8.0 pickleshare==0.7.5 @@ -123,7 +123,7 @@ platformdirs==2.5.2 pluggy==1.0.0 pre-commit==2.20.0 prometheus-client==0.14.1 -prompt-toolkit==3.0.30 +prompt-toolkit==3.0.31 psutil==5.9.1 ptyprocess==0.7.0 pure-eval==0.2.2 @@ -140,7 +140,7 @@ pygments==2.13.0 pylint==2.15.0 pyparsing==3.0.9 pyrsistent==0.18.1 -pytest==7.1.2 +pytest==7.1.3 pytest-cov==3.0.0 pytest-forked==1.4.0 pytest-profiling==1.7.0 @@ -157,7 +157,7 @@ restructuredtext-lint==1.4.0 rich==12.5.1 send2trash==1.8.0 six==1.16.0 -sniffio==1.2.0 +sniffio==1.3.0 snowballstemmer==2.2.0 soupsieve==2.3.2.post1 sphinx==4.5.0 @@ -179,7 +179,7 @@ sphinxcontrib-qthelp==1.0.3 sphinxcontrib-serializinghtml==1.1.5 sqlalchemy==1.4.40 stack-data==0.5.0 -sympy==1.10.1 +sympy==1.11.1 tabulate==0.8.10 terminado==0.15.0 tinycss2==1.1.1 @@ -188,7 +188,7 @@ tomli==2.0.1 tomlkit==0.11.4 tornado==6.2 tox==3.25.1 -tqdm==4.64.0 +tqdm==4.64.1 traitlets==5.3.0 types-docutils==0.19.0 types-pkg-resources==0.1.3 @@ -202,7 +202,7 @@ wcwidth==0.2.5 webencodings==0.5.1 websocket-client==1.4.0 wheel==0.37.1 -widgetsnbextension==4.0.2 +widgetsnbextension==4.0.3 wrapt==1.14.1 zipp==3.8.1 diff --git a/.cspell.json b/.cspell.json index 68d6a3f85..fbac05684 100644 --- a/.cspell.json +++ b/.cspell.json @@ -90,6 +90,7 @@ "coolwarm", "displaystyle", "dlink", + "docnb", "doctest", "doctests", "dotprint", diff --git a/.github/workflows/ci-docs.yml b/.github/workflows/ci-docs.yml index f92a1f96c..c44cce26f 100644 --- a/.github/workflows/ci-docs.yml +++ b/.github/workflows/ci-docs.yml @@ -32,27 +32,6 @@ jobs: GITHUB_REPO: ${{ github.event.pull_request.head.repo.full_name }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: tox -e docnb - # cspell:ignore docnb - - name: Print error logs with color - if: ${{ failure() }} - # cspell:ignore printf - run: | - for log_file in $(ls docs/_build/html/reports/*); do - for i in $(seq 6); do echo; done - printf '%45s\n' | tr ' ' = - echo "$log_file" - printf '%45s\n' | tr ' ' = - echo - cat "$log_file" - done - for log_file in $(ls /tmp/sphinx-*.log); do - for i in $(seq 6); do echo; done - printf '%45s\n' | tr ' ' = - echo "$log_file" - printf '%45s\n' | tr ' ' = - echo - cat "$log_file" - done - uses: actions/upload-artifact@v3 if: ${{ always() }} with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0588862e3..8d668b1c3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,7 +41,7 @@ repos: - id: trailing-whitespace - repo: https://github.com/ComPWA/repo-maintenance - rev: 0.0.143 + rev: 0.0.144 hooks: - id: check-dev-files args: diff --git a/docs/conf.py b/docs/conf.py index 151c4053a..6d13fa356 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -349,6 +349,7 @@ def get_execution_mode() -> str: nb_execution_mode = get_execution_mode() +nb_execution_show_tb = True nb_execution_timeout = -1 nb_output_stderr = "remove" EXECUTE_NB = nb_execution_mode != "off" diff --git a/setup.cfg b/setup.cfg index 6dded33de..d940bc2bb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -45,7 +45,7 @@ install_requires = attrs >=20.1.0 # on_setattr and https://www.attrs.org/en/stable/api.html#next-gen qrules ==0.9.*, >=0.9.6 # https://github.com/ComPWA/qrules/pull/145 singledispatchmethod; python_version <"3.8.0" - sympy >=1.10, <1.11 # module sympy.printing.numpy and array expressions with shape kwarg + sympy >=1.10, <1.12 # module sympy.printing.numpy and array expressions with shape kwarg typing-extensions; python_version <"3.8.0" packages = find: package_dir = diff --git a/src/ampform/dynamics/__init__.py b/src/ampform/dynamics/__init__.py index 1bebbe451..826df6aeb 100644 --- a/src/ampform/dynamics/__init__.py +++ b/src/ampform/dynamics/__init__.py @@ -216,8 +216,8 @@ def evaluate(self) -> sp.Expr: def _latex(self, printer: LatexPrinter, *args) -> str: s = printer._print(self.args[0]) - width = printer._print(self.args[2]) - subscript = _indices_to_subscript(_determine_indices(width)) + gamma0 = self.args[2] + subscript = _indices_to_subscript(_determine_indices(gamma0)) name = Rf"\Gamma{subscript}" if self._name is None else self._name return Rf"{name}\left({s}\right)" diff --git a/src/ampform/dynamics/phasespace.py b/src/ampform/dynamics/phasespace.py index f09c943cb..3dacac8b6 100644 --- a/src/ampform/dynamics/phasespace.py +++ b/src/ampform/dynamics/phasespace.py @@ -88,10 +88,11 @@ def evaluate(self) -> sp.Expr: return (s - (m_a + m_b) ** 2) * (s - (m_a - m_b) ** 2) / (4 * s) # type: ignore[operator] def _latex(self, printer: LatexPrinter, *args) -> str: - s = printer._print(self.args[0]) + s = self.args[0] + s_latex = printer._print(self.args[0]) subscript = _indices_to_subscript(_determine_indices(s)) name = "q^2" + subscript if self._name is None else self._name - return Rf"{name}\left({s}\right)" + return Rf"{name}\left({s_latex}\right)" @implement_doit_method @@ -113,10 +114,11 @@ def evaluate(self) -> sp.Expr: return sp.sqrt(q_squared) / denominator def _latex(self, printer: LatexPrinter, *args) -> str: - s = printer._print(self.args[0]) - subscript = _indices_to_subscript(_determine_indices(s)) + s_symbol = self.args[0] + s_latex = printer._print(s_symbol) + subscript = _indices_to_subscript(_determine_indices(s_symbol)) name = R"\rho" + subscript if self._name is None else self._name - return Rf"{name}\left({s}\right)" + return Rf"{name}\left({s_latex}\right)" @implement_doit_method @@ -143,10 +145,11 @@ def evaluate(self) -> sp.Expr: return sp.sqrt(sp.Abs(q_squared)) / denominator def _latex(self, printer: LatexPrinter, *args) -> str: - s = printer._print(self.args[0]) - subscript = _indices_to_subscript(_determine_indices(s)) + s_symbol = self.args[0] + s_latex = printer._print(s_symbol) + subscript = _indices_to_subscript(_determine_indices(s_symbol)) name = R"\hat{\rho}" + subscript if self._name is None else self._name - return Rf"{name}\left({s}\right)" + return Rf"{name}\left({s_latex}\right)" @implement_doit_method @@ -169,10 +172,11 @@ def evaluate(self) -> sp.Expr: return ComplexSqrt(q_squared) / denominator def _latex(self, printer: LatexPrinter, *args) -> str: - s = printer._print(self.args[0]) - subscript = _indices_to_subscript(_determine_indices(s)) + s_symbol = self.args[0] + s_latex = printer._print(s_symbol) + subscript = _indices_to_subscript(_determine_indices(s_symbol)) name = R"\rho^\mathrm{c}" + subscript if self._name is None else self._name - return Rf"{name}\left({s}\right)" + return Rf"{name}\left({s_latex}\right)" @implement_doit_method @@ -194,10 +198,11 @@ def evaluate(self) -> sp.Expr: return -sp.I * chew_mandelstam def _latex(self, printer: LatexPrinter, *args) -> str: - s = printer._print(self.args[0]) - subscript = _indices_to_subscript(_determine_indices(s)) + s_symbol = self.args[0] + s_latex = printer._print(s_symbol) + subscript = _indices_to_subscript(_determine_indices(s_symbol)) name = R"\rho^\mathrm{CM}" + subscript if self._name is None else self._name - return Rf"{name}\left({s}\right)" + return Rf"{name}\left({s_latex}\right)" def chew_mandelstam_s_wave(s, m_a, m_b): @@ -243,10 +248,11 @@ def evaluate(self) -> sp.Expr: return _analytic_continuation(rho_hat, s, s_threshold) def _latex(self, printer: LatexPrinter, *args) -> str: - s = printer._print(self.args[0]) - subscript = _indices_to_subscript(_determine_indices(s)) + s_symbol = self.args[0] + s_latex = printer._print(s_symbol) + subscript = _indices_to_subscript(_determine_indices(s_symbol)) name = R"\rho^\mathrm{eq}" + subscript if self._name is None else self._name - return Rf"{name}\left({s}\right)" + return Rf"{name}\left({s_latex}\right)" def _analytic_continuation(rho, s, s_threshold) -> sp.Piecewise: diff --git a/src/ampform/helicity/__init__.py b/src/ampform/helicity/__init__.py index bcbb61e2a..70ef3f154 100644 --- a/src/ampform/helicity/__init__.py +++ b/src/ampform/helicity/__init__.py @@ -258,7 +258,7 @@ def sum_components(self, components: Iterable[str]) -> sp.Expr: # noqa: R701 if all(c.startswith("I") for c in components): return sum(self.components[c] for c in components) # type: ignore[return-value] if all(c.startswith("A") for c in components): - return abs(sum(self.components[c] for c in components)) ** 2 + return sp.Abs(sum(self.components[c] for c in components)) ** 2 raise ValueError('Not all component names started with either "A" or "I"') @@ -640,7 +640,7 @@ def __register_amplitudes(self, transition_group: list[StateTransition]) -> None first_transition = transition_group[0] graph_group_label = generate_transition_label(first_transition) component_name = f"I_{{{graph_group_label}}}" - self.__ingredients.components[component_name] = abs(expression) ** 2 + self.__ingredients.components[component_name] = sp.Abs(expression) ** 2 def __formulate_topology_amplitude( self, transitions: Sequence[StateTransition] diff --git a/src/ampform/kinematics/__init__.py b/src/ampform/kinematics/__init__.py index 2f95119f9..00bae164e 100644 --- a/src/ampform/kinematics/__init__.py +++ b/src/ampform/kinematics/__init__.py @@ -705,9 +705,9 @@ def __new__( # pylint: disable=too-many-arguments return create_expression(cls, angle, cos_angle, sin_angle, ones, zeros, **hints) def _latex(self, printer: LatexPrinter, *args) -> str: - angle, *_ = self.args - angle = printer._print(angle) - return Rf"\boldsymbol{{R_y}}\left({angle}\right)" + angle_symbol, *_ = self.args + angle_latex = printer._print(angle_symbol) + return Rf"\boldsymbol{{R_y}}\left({angle_latex}\right)" def _numpycode(self, printer: NumPyPrinter, *args) -> str: printer.module_imports[printer._module].add("array") @@ -778,9 +778,9 @@ def __new__( # pylint: disable=too-many-arguments return create_expression(cls, angle, cos_angle, sin_angle, ones, zeros, **hints) def _latex(self, printer: LatexPrinter, *args) -> str: - angle, *_ = self.args - angle = printer._print(angle) - return Rf"\boldsymbol{{R_z}}\left({angle}\right)" + angle_symbol, *_ = self.args + angle_latex = printer._print(angle_symbol) + return Rf"\boldsymbol{{R_z}}\left({angle_latex}\right)" def _numpycode(self, printer: NumPyPrinter, *args) -> str: printer.module_imports[printer._module].add("array") diff --git a/src/ampform/sympy/__init__.py b/src/ampform/sympy/__init__.py index 78967e0f5..10053b854 100644 --- a/src/ampform/sympy/__init__.py +++ b/src/ampform/sympy/__init__.py @@ -419,17 +419,17 @@ def _render_sum_symbol( ) -> str: if len(values) == 0: return "" - idx = printer._print(idx) + idx_latex = printer._print(idx) if len(values) == 1: value = values[0] - return Rf"\sum_{{{idx}={value}}}" + return Rf"\sum_{{{idx_latex}={value}}}" if _is_regular_series(values): sorted_values = sorted(values, key=float) first_value = sorted_values[0] last_value = sorted_values[-1] - return Rf"\sum_{{{idx}={first_value}}}^{{{last_value}}}" + return Rf"\sum_{{{idx_latex}={first_value}}}^{{{last_value}}}" idx_values = ",".join(map(printer._print, values)) - return Rf"\sum_{{{idx}\in\left\{{{idx_values}\right\}}}}" + return Rf"\sum_{{{idx_latex}\in\left\{{{idx_values}\right\}}}}" def _is_regular_series(values: Sequence[SupportsFloat]) -> bool: diff --git a/src/ampform/sympy/_array_expressions.py b/src/ampform/sympy/_array_expressions.py index dbf87bb34..f22f67a3c 100644 --- a/src/ampform/sympy/_array_expressions.py +++ b/src/ampform/sympy/_array_expressions.py @@ -314,15 +314,13 @@ def _strip_subscript_superscript(symbol: sp.Basic) -> str: @make_commutative class ArrayAxisSum(sp.Expr): - def __new__( - cls, array: ArraySymbol, axis: int | None = None, **hints - ) -> ArrayAxisSum: + def __new__(cls, array: sp.Expr, axis: int | None = None, **hints) -> ArrayAxisSum: if axis is not None and not isinstance(axis, (int, sp.Integer)): raise TypeError("Only single digits allowed for axis") return create_expression(cls, array, axis, **hints) @property - def array(self) -> ArraySymbol: + def array(self) -> sp.Expr: return self.args[0] # type: ignore[return-value] @property