diff --git a/python/lsst/pipe/base/tests/mocks/_pipeline_task.py b/python/lsst/pipe/base/tests/mocks/_pipeline_task.py index c8bdda9b..0c2d7ca2 100644 --- a/python/lsst/pipe/base/tests/mocks/_pipeline_task.py +++ b/python/lsst/pipe/base/tests/mocks/_pipeline_task.py @@ -44,13 +44,14 @@ from typing import TYPE_CHECKING, Any, ClassVar, TypeVar from astropy.units import Quantity -from lsst.daf.butler import DataCoordinate, DatasetRef, DeferredDatasetHandle, SerializedDatasetType +from lsst.daf.butler import DataCoordinate, DatasetRef, DeferredDatasetHandle, Quantum, SerializedDatasetType from lsst.pex.config import Config, ConfigDictField, ConfigurableField, Field, ListField from lsst.utils.doImport import doImportType from lsst.utils.introspection import get_full_type_name from lsst.utils.iteration import ensure_iterable from ... import connectionTypes as cT +from ..._status import AnnotatedPartialOutputsError, RepeatableQuantumError from ...config import PipelineTaskConfig from ...connections import InputQuantizedConnection, OutputQuantizedConnection, PipelineTaskConnections from ...pipeline_graph import PipelineGraph @@ -291,7 +292,7 @@ def runQuantum( # Possibly raise an exception. if self.data_id_match is not None and self.data_id_match.match(quantum.dataId): assert self.fail_exception is not None, "Exception type must be defined" - message = f"Simulated failure: task={self.getName()} dataId={quantum.dataId}" + if self.memory_required is not None: if butlerQC.resources.max_mem < self.memory_required: _LOG.info( @@ -299,10 +300,10 @@ def runQuantum( self.getName(), quantum.dataId, ) - raise self.fail_exception(message) + self._fail(quantum) else: _LOG.info("Simulating failure of task '%s' on quantum %s", self.getName(), quantum.dataId) - raise self.fail_exception(message) + self._fail(quantum) # Populate the bit of provenance we store in all outputs. _LOG.info("Reading input data for task '%s' on quantum %s", self.getName(), quantum.dataId) @@ -351,6 +352,25 @@ def runQuantum( _LOG.info("Finished mocking task '%s' on quantum %s", self.getName(), quantum.dataId) + def _fail(self, quantum: Quantum) -> None: + """Raise the configured exception. + + Parameters + ---------- + quantum : `lsst.daf.butler.Quantum` + Quantum producing the error. + """ + message = f"Simulated failure: task={self.getName()} dataId={quantum.dataId}" + if self.fail_exception is AnnotatedPartialOutputsError: + # This exception is expected to always chain another. + try: + raise RepeatableQuantumError(message) + except RepeatableQuantumError as err: + raise AnnotatedPartialOutputsError() from err + else: + assert self.fail_exception is not None, "Method should not be called." + raise self.fail_exception(message) + class MockPipelineDefaultTargetConnections(PipelineTaskConnections, dimensions=()): pass