Skip to content

Commit

Permalink
JP-3618: Add intermediate LRS slit wcs reference frame (spacetelescop…
Browse files Browse the repository at this point in the history
…e#8475)

Co-authored-by: Howard Bushouse <[email protected]>
  • Loading branch information
drlaw1558 and hbushouse authored Jun 21, 2024
1 parent 5a89d78 commit 33e3e67
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 30 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ assign_wcs
- Update default parameters to increase the accuracy of the SIP approximation
in the output FITS WCS. [#8529]

- Update MIRI LRS WCS code to introduce an intermediate alpha-beta slit reference frame
between pixel coordinates and the v2/v3 frame. [#8475]

- Added handling for fixed slit sources defined in a MSA metadata file, for combined
NIRSpec MOS and fixed slit observations. Slits are now appended to the data
product in the order they appear in the MSA file. [#8467]
Expand Down
134 changes: 104 additions & 30 deletions jwst/assign_wcs/miri.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,15 @@ def lrs(input_model, reference_files):
# Define the various coordinate frames.
# Original detector frame
detector = cf.Frame2D(name='detector', axes_order=(0, 1), unit=(u.pix, u.pix))

# Intermediate slit frame
alpha_beta = cf.Frame2D(name='alpha_beta_spatial', axes_order=(0, 1),
unit=(u.arcsec, u.arcsec), axes_names=('alpha', 'beta'))
spec_local = cf.SpectralFrame(name='alpha_beta_spectral', axes_order=(2,),
unit=(u.micron,), axes_names=('lambda',))
miri_focal = cf.CompositeFrame([alpha_beta, spec_local], name='alpha_beta')


# Spectral component
spec = cf.SpectralFrame(name='spec', axes_order=(2,), unit=(u.micron,), axes_names=('lambda',))
# v2v3 spatial component
Expand All @@ -193,7 +202,8 @@ def lrs(input_model, reference_files):
world = cf.CompositeFrame(name="world", frames=[icrs, spec])

# Create the transforms
dettotel = lrs_distortion(input_model, reference_files)
dettoabl = lrs_xytoabl(input_model, reference_files)
abltov2v3l = lrs_abltov2v3l(input_model, reference_files)
v2v3tosky = pointing.v23tosky(input_model)
teltosky = v2v3tosky & models.Identity(1)

Expand All @@ -205,19 +215,20 @@ def lrs(input_model, reference_files):
) & models.Identity(1)

# Put the transforms together into a single pipeline
pipeline = [(detector, dettotel),
pipeline = [(detector, dettoabl),
(miri_focal, abltov2v3l),
(v2v3, va_corr),
(v2v3vacorr, teltosky),
(world, None)]

return pipeline


def lrs_distortion(input_model, reference_files):
def lrs_xytoabl(input_model, reference_files):
"""
The LRS-FIXEDSLIT and LRS-SLITLESS WCS pipeline.
The first part of LRS-FIXEDSLIT and LRS-SLITLESS WCS pipeline.
Transform from subarray (x, y) to (v2, v3, lambda) using
Transform from subarray (x, y) to (alpha, beta, lambda) using
the "specwcs" and "distortion" reference files.
"""
Expand Down Expand Up @@ -249,6 +260,11 @@ def lrs_distortion(input_model, reference_files):
# Transform to slitless subarray from full array
zero_point = subarray2full.inverse(zero_point[0], zero_point[1])

# Figure out the typical along-slice pixel scale at the center of the slit
v2_cen, v3_cen = subarray_dist(zero_point[0], zero_point[1])
v2_off, v3_off = subarray_dist(zero_point[0] + 1, zero_point[1])
pscale = np.sqrt(np.power(v2_cen - v2_off, 2) + np.power(v3_cen - v3_off,2))

# In the lrsdata reference table, X_center,y_center,wavelength describe the location of the
# centroid trace along the detector in pixels relative to nominal location.
# x0,y0(ul) x1,y1 (ur) x2,y2(lr) x3,y3(ll) define corners of the box within which the distortion
Expand All @@ -275,18 +291,6 @@ def lrs_distortion(input_model, reference_files):
bb_sub = ((input_model.meta.subarray.xstart - 1 + 4 - 0.5, input_model.meta.subarray.xsize - 1 + 0.5),
(np.floor(y2.min() + zero_point[1]) - 0.5, np.ceil(y0.max() + zero_point[1]) + 0.5))

# Find the ROW of the zero point
row_zero_point = zero_point[1]

# The inputs to the "detector_to_v2v3" transform are
# - the indices in x spanning the entire image row
# - y is the y-value of the zero point
# This is equivalent of making a vector of x, y locations for
# every pixel in the reference row
const1d = models.Const1D(row_zero_point)
const1d.inverse = models.Const1D(row_zero_point)
det_to_v2v3 = models.Identity(1) & const1d | subarray_dist

# Now deal with the fact that the spectral trace isn't perfectly up and down along detector.
# This information is contained in the xcenter/ycenter values in the CDP table, but we'll handle it
# as a simple x shift using a linear fit to this relation provided by the CDP.
Expand Down Expand Up @@ -330,34 +334,104 @@ def lrs_distortion(input_model, reference_files):
ymodel = models.Mapping([1], n_inputs=2)
# What is the effective XY as a function of subarray x,y?
xymodel = models.Mapping((0, 1, 0, 1)) | xmodel & ymodel
# What is the alpha as a function of slit XY?
alphamodel = models.Mapping([0], n_inputs=2) | models.Shift(-zero_point[0]) | models.Polynomial1D(1, c0=0, c1=pscale)
# What is the alpha,beta as a function of slit XY? (beta is always zero)
abmodel = models.Mapping((0, 1, 0)) | alphamodel & models.Const1D(0)

# Define a shift by the reference point and immediately back again
# This doesn't do anything effectively, but it stores the reference point for later use in pathloss
reftransform = models.Shift(-zero_point[0]) & models.Shift(-zero_point[1]) | models.Shift(+zero_point[0]) & models.Shift(+zero_point[1])
# Put the transforms together
xytov2v3 = reftransform | xymodel | det_to_v2v3
xytoab = reftransform | xymodel | abmodel

# Construct the full distortion model (xsub,ysub -> v2,v3,wavelength)
# Construct the full distortion model (xsub,ysub -> alpha,beta,wavelength)
lrs_wav_model = models.Mapping([1], n_inputs=2) | wavemodel
dettotel = models.Mapping((0, 1, 0, 1)) | xytov2v3 & lrs_wav_model
dettoabl = models.Mapping((0, 1, 0, 1)) | xytoab & lrs_wav_model

# Construct the inverse distortion model (v2,v3,wavelength -> xsub,ysub)
# Go from v2,v3 to slit-x
v2v3_to_xdet = det_to_v2v3.inverse | models.Mapping([0], n_inputs=2)
# Construct the inverse distortion model (alpha,beta,wavelength -> xsub,ysub)
# Go from alpha to slit-X
slitxmodel = models.Polynomial1D(1, c0=0, c1=1/pscale) | models.Shift(zero_point[0])
# Go from lambda to real y
lam_to_y = wavemodel.inverse
# Go from slit-x and real y to real-x
backwards = models.Mapping([0], n_inputs=2) + (models.Mapping([1], n_inputs=2) | dxmodel)
# Go from v2,v3,lam to real x
aa = v2v3_to_xdet & lam_to_y | backwards
# Go from v2,v3,lam to real y

# Go from alpha,beta,lam to real x
aa = models.Mapping((0, 2)) | slitxmodel & lam_to_y | backwards
# Go from alpha,beta,lam to real y
bb = models.Mapping([2], n_inputs=3) | lam_to_y
# Go from v2,v3,lam, to real x,y
dettotel.inverse = models.Mapping((0, 1, 2, 0, 1, 2)) | aa & bb
# Go from alpha,beta,lam, to real x,y
dettoabl.inverse = models.Mapping((0, 1, 2, 0, 1, 2)) | aa & bb

# Bounding box is the subarray bounding box, because we're assuming subarray coordinates passed in
dettotel.bounding_box = bb_sub[::-1]
dettoabl.bounding_box = bb_sub[::-1]

return dettoabl

def lrs_abltov2v3l(input_model, reference_files):
"""
The second part of LRS-FIXEDSLIT and LRS-SLITLESS WCS pipeline.
Transform from (alpha, beta, lambda) to (v2, v3, lambda) using
the "specwcs" and "distortion" reference files.
"""

# subarray to full array transform
subarray2full = subarray_transform(input_model)

# full array to v2v3 transform for the ordinary imager
with DistortionModel(reference_files['distortion']) as dist:
distortion = dist.model

# Combine models to create subarray to v2v3 distortion
if subarray2full is not None:
subarray_dist = subarray2full | distortion
else:
subarray_dist = distortion

ref = fits.open(reference_files['specwcs'])

with ref:
# Get the zero point from the reference data.
# The zero_point is X, Y (which should be COLUMN, ROW)
# These are 1-indexed in CDP-7 (i.e., SIAF convention) so must be converted to 0-indexed
if input_model.meta.exposure.type.lower() == 'mir_lrs-fixedslit':
zero_point = ref[0].header['imx'] - 1, ref[0].header['imy'] - 1
elif input_model.meta.exposure.type.lower() == 'mir_lrs-slitless':
zero_point = ref[0].header['imxsltl'] - 1, ref[0].header['imysltl'] - 1
# Transform to slitless subarray from full array
zero_point = subarray2full.inverse(zero_point[0], zero_point[1])

return dettotel
# Figure out the typical along-slice pixel scale at the center of the slit
v2_cen, v3_cen = subarray_dist(zero_point[0], zero_point[1])
v2_off, v3_off = subarray_dist(zero_point[0] + 1, zero_point[1])
pscale = np.sqrt(np.power(v2_cen - v2_off, 2) + np.power(v3_cen - v3_off,2))

# Go from alpha to slit-X
slitxmodel = models.Polynomial1D(1, c0=0, c1=1 / pscale) | models.Shift(zero_point[0])
# Go from beta to slit-Y (row_zero_point plus some offset)
# Beta should always be zero unless using in a pseudo-ifu mode
slitymodel = models.Polynomial1D(1, c0=0, c1=1 / pscale) | models.Shift(zero_point[1])
# Go from alpha-beta to slit xy, and onward to v2v3
ab_to_v2v3 = slitxmodel & slitymodel | subarray_dist
# Put it together to pass through wavelength
abl_to_v2v3l = models.Mapping((0, 1, 2)) | ab_to_v2v3 & models.Identity(1)

# Define the inverse transform
# Go from slit X to alpha
alphamodel = models.Shift(-zero_point[0]) | models.Polynomial1D(1, c0=0, c1=pscale)
# Go from slit Y to beta
betamodel = models.Shift(-zero_point[1]) | models.Polynomial1D(1, c0=0, c1=pscale)
# Go from v2,v3 to slit-x,slit-y
v2v3_to_xydet = subarray_dist.inverse
# Go from v2,v3 to alpha, beta
aa = v2v3_to_xydet | alphamodel & betamodel
# Go from v2,v3,lambda to alpha,beta,lambda
abl_to_v2v3l.inverse = models.Mapping((0,1,2)) | aa & models.Identity(1)

return abl_to_v2v3l

def ifu(input_model, reference_files):
"""
Expand Down

0 comments on commit 33e3e67

Please sign in to comment.