diff --git a/galpy/potential/WrapperPotential.py b/galpy/potential/WrapperPotential.py index 0c474a959..4dee721d9 100644 --- a/galpy/potential/WrapperPotential.py +++ b/galpy/potential/WrapperPotential.py @@ -10,6 +10,7 @@ planarPotential, ) from .Potential import ( + Force, Potential, _dim, _evaluatephitorques, @@ -90,6 +91,19 @@ def __init__(self, amp=1.0, pot=None, ro=None, vo=None, _init=None, **kwargs): return None # Don't run __init__ at the end of setup Potential.__init__(self, amp=amp, ro=ro, vo=vo) self._pot = pot + # Check that we are not wrapping a non-potential Force object + if ( + isinstance(self._pot, list) + and any( + [ + isinstance(p, Force) and not isinstance(p, Potential) + for p in self._pot + ] + ) + ) or (isinstance(self._pot, Force) and not isinstance(self._pot, Potential)): + raise RuntimeError( + "WrapperPotential cannot currently wrap non-Potential Force objects" + ) self.isNonAxi = _isNonAxi(self._pot) # Check whether units are consistent between the wrapper and the # wrapped potential diff --git a/tests/test_potential.py b/tests/test_potential.py index ac09e8292..d019e6e38 100644 --- a/tests/test_potential.py +++ b/tests/test_potential.py @@ -5701,6 +5701,31 @@ def test_Wrapper_incompatibleunitserror(): return None +def test_Wrapper_Force_error(): + # Test that applying a wrapper to a DissipativeForce does not currently work + def M(t): + return 1.0 + + # Initialize potentials and time-varying potentials + df = potential.ChandrasekharDynamicalFrictionForce(GMs=1.0) + with pytest.raises(RuntimeError) as excinfo: + df_wrap = potential.TimeDependentAmplitudeWrapperPotential(A=M, amp=1, pot=df) + assert ( + "WrapperPotential cannot currently wrap non-Potential Force objects" + == excinfo.value.args[0] + ) + # Also test for list + with pytest.raises(RuntimeError) as excinfo: + df_wrap = potential.TimeDependentAmplitudeWrapperPotential( + A=M, amp=1, pot=potential.MWPotential2014 + df + ) + assert ( + "WrapperPotential cannot currently wrap non-Potential Force objects" + == excinfo.value.args[0] + ) + return None + + def test_WrapperPotential_unittransfer_3d(): # Test that units are properly transferred between a potential and its # wrapper