Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Removed redundant calculations, unneeded array allocations from jump step #302

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ General
<https://github.com/spacetelescope/stcal/issues/286>`_)
- Improve handling of catalog web service connectivity issues. (`#286
<https://github.com/spacetelescope/stcal/issues/286>`_)

- [jump] Remove redundant calculations and unneeded array allocations.
(`JP-3697<https://jira.stsci.edu/browse/JP-3697>`_)
Comment on lines -79 to +80
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be removed and added to a towncrier change note instead.


1.8.2 (2024-09-10)
==================
Expand Down
59 changes: 20 additions & 39 deletions src/stcal/jump/twopoint_difference.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,6 @@ def find_crs(
return gdq, row_below_gdq, row_above_gdq, 0, dummy
else:
# set 'saturated' or 'do not use' pixels to nan in data
dat[np.where(np.bitwise_and(gdq, sat_flag))] = np.nan
dat[np.where(np.bitwise_and(gdq, dnu_flag))] = np.nan
dat[np.where(np.bitwise_and(gdq, dnu_flag + sat_flag))] = np.nan
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any chance that either the dnu or sat flags might be set without both being set?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bottom line subsumes the previous two, right? As long as dnu_flag and sat_flag are different, bitwise_and will be nonzero any time either of those bits is set, so np.where will pick out all points where either flag is set? I tested this myself to be absolutely sure.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, the syntax
np.where(np.bitwise_and(gdq, flag_a + flag_b + flag_c))
should return all pixels where at least one of flags a, b, and c is set. If it is clearer, we could write something like
dat[np.bitwise_and(gdq, flag_a + flag_b + flag_c) != 0] = np.nan
instead.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yes, you're right. I always struggle with bit-wise math. Thanks!

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! The only hole this opens is if someone is foolish enough to make sat_flag and dnu_flag equal. Hopefully that gets checked somewhere? :-)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pretty sure that's not possible, but to be super safe, we could cover that case by using | instead of +.


# calculate the differences between adjacent groups (first diffs)
Expand All @@ -183,19 +181,13 @@ def find_crs(

# calc. the median of first_diffs for each pixel along the group axis
first_diffs_masked = np.ma.masked_array(first_diffs, mask=np.isnan(first_diffs))
median_diffs = np.ma.median(first_diffs_masked, axis=(0, 1))
median_diffs = np.nanmedian(first_diffs.reshape((-1, nrows, ncols)), axis=0)
# calculate sigma for each pixel
sigma = np.sqrt(np.abs(median_diffs) + read_noise_2 / nframes)

# reset sigma so pxels with 0 readnoise are not flagged as jumps
sigma[np.where(sigma == 0.)] = np.nan

# compute 'ratio' for each group. this is the value that will be
# compared to 'threshold' to classify jumps. subtract the median of
# first_diffs from first_diffs, take the absolute value and divide by sigma.
e_jump_4d = first_diffs - median_diffs[np.newaxis, :, :]
ratio_all = np.abs(first_diffs - median_diffs[np.newaxis, np.newaxis, :, :]) / \
sigma[np.newaxis, np.newaxis, :, :]
# reset sigma so pixels with 0 readnoise are not flagged as jumps
sigma[sigma == 0.] = np.nan

# Test to see if there are enough groups to use sigma clipping
if (only_use_ints and nints >= minimum_sigclip_groups) or \
(not only_use_ints and total_groups >= minimum_sigclip_groups):
Expand Down Expand Up @@ -231,60 +223,48 @@ def find_crs(
warnings.resetwarnings()
else: # There are not enough groups for sigma clipping

# set 'saturated' or 'do not use' pixels to nan in data
dat[np.where(np.bitwise_and(gdq, sat_flag))] = np.nan
dat[np.where(np.bitwise_and(gdq, dnu_flag))] = np.nan

# calculate the differences between adjacent groups (first diffs)
# use mask on data, so the results will have sat/donotuse groups masked
first_diffs = np.diff(dat, axis=1)

if total_usable_diffs >= min_diffs_single_pass:
warnings.filterwarnings("ignore", ".*All-NaN slice encountered.*", RuntimeWarning)
median_diffs = np.nanmedian(first_diffs, axis=(0, 1))
warnings.resetwarnings()
# calculate sigma for each pixel
sigma = np.sqrt(np.abs(median_diffs) + read_noise_2 / nframes)
# reset sigma so pixels with 0 read noise are not flagged as jumps
sigma[np.where(sigma == 0.)] = np.nan

# compute 'ratio' for each group. this is the value that will be
# compared to 'threshold' to classify jumps. subtract the median of
# first_diffs from first_diffs, take the abs. value and divide by sigma.
e_jump = first_diffs - median_diffs[np.newaxis, np.newaxis, :, :]

ratio = np.abs(e_jump) / sigma[np.newaxis, np.newaxis, :, :]
ratio = (np.abs(first_diffs - median_diffs[np.newaxis, np.newaxis, :]) / sigma[np.newaxis, :]).astype(np.float32)
masked_ratio = np.ma.masked_greater(ratio, normal_rej_thresh)
del ratio
# The jump mask is the ratio greater than the threshold and the difference is usable
jump_mask = np.logical_and(masked_ratio.mask, np.logical_not(first_diffs_masked.mask))
gdq[:, 1:, :, :] = np.bitwise_or(gdq[:, 1:, :, :], jump_mask *
np.uint8(dqflags["JUMP_DET"]))
del masked_ratio

else: # low number of diffs requires iterative flagging
# calculate the differences between adjacent groups (first diffs)
# use mask on data, so the results will have sat/donotuse groups masked
first_diffs = np.abs(np.diff(dat, axis=1))
first_diffs_abs = np.abs(first_diffs)

# calc. the median of first_diffs for each pixel along the group axis
median_diffs = calc_med_first_diffs(first_diffs)
median_diffs_abs = calc_med_first_diffs(first_diffs_abs)

# calculate sigma for each pixel
sigma = np.sqrt(np.abs(median_diffs) + read_noise_2 / nframes)
sigma_abs = np.sqrt(np.abs(median_diffs_abs) + read_noise_2 / nframes)
# reset sigma so pxels with 0 readnoise are not flagged as jumps
sigma[np.where(sigma == 0.0)] = np.nan
sigma_abs[np.where(sigma_abs == 0.0)] = np.nan

# compute 'ratio' for each group. this is the value that will be
# compared to 'threshold' to classify jumps. subtract the median of
# first_diffs from first_diffs, take the abs. value and divide by sigma.
e_jump = first_diffs - median_diffs[np.newaxis, :, :]
ratio = np.abs(e_jump) / sigma[np.newaxis, :, :]
e_jump = first_diffs_abs - median_diffs_abs[np.newaxis, :, :]
ratio = np.abs(e_jump) / sigma_abs[np.newaxis, :, :]

# create a 2d array containing the value of the largest 'ratio' for each pixel
warnings.filterwarnings("ignore", ".*All-NaN slice encountered.*", RuntimeWarning)
max_ratio = np.nanmax(ratio, axis=1)
warnings.resetwarnings()
# now see if the largest ratio of all groups for each pixel exceeds the threshold.
# there are different threshold for 4+, 3, and 2 usable groups
num_unusable_groups = np.sum(np.isnan(first_diffs), axis=(0, 1))
num_unusable_groups = np.sum(np.isnan(first_diffs_abs), axis=(0, 1))
int4cr, row4cr, col4cr = np.where(
np.logical_and(ndiffs - num_unusable_groups >= 4, max_ratio > normal_rej_thresh)
)
Expand All @@ -306,7 +286,7 @@ def find_crs(
# repeat this process until no more CRs are found.
for j in range(len(all_crs_row)):
# get arrays of abs(diffs), ratio, readnoise for this pixel
pix_first_diffs = first_diffs[:, :, all_crs_row[j], all_crs_col[j]]
pix_first_diffs = first_diffs_abs[:, :, all_crs_row[j], all_crs_col[j]]
pix_ratio = ratio[:, :, all_crs_row[j], all_crs_col[j]]
pix_rn2 = read_noise_2[all_crs_row[j], all_crs_col[j]]

Expand Down Expand Up @@ -358,7 +338,8 @@ def find_crs(
num_primary_crs = len(cr_group)
if flag_4_neighbors: # iterate over each 'jump' pixel
for j in range(len(cr_group)):
ratio_this_pix = ratio_all[cr_integ[j], cr_group[j] - 1, cr_row[j], cr_col[j]]
_i, _j, _k, _l = (cr_integ[j], cr_group[j] - 1, cr_row[j], cr_col[j])
ratio_this_pix = np.abs(first_diffs[_i, _j, _k, _l] - median_diffs[_k, _l])/sigma[_k, _l]

# Jumps must be in a certain range to have neighbors flagged
if (ratio_this_pix < max_jump_to_flag_neighbors) and (
Expand Down Expand Up @@ -431,7 +412,8 @@ def find_crs(
group = cr_group[j]
row = cr_row[j]
col = cr_col[j]
if e_jump_4d[intg, group - 1, row, col] >= cthres:
ejump_this_pix = first_diffs[intg, group - 1, row, col] - median_diffs[row, col]
if ejump_this_pix >= cthres:
for kk in range(group, min(group + cgroup + 1, ngroups)):
if (gdq[intg, kk, row, col] & sat_flag) == 0 and (
gdq[intg, kk, row, col] & dnu_flag
Expand All @@ -445,7 +427,6 @@ def find_crs(
dummy = np.zeros((dataa.shape[1] - 1, dataa.shape[2], dataa.shape[3]), dtype=np.float32)
else:
dummy = np.zeros((dataa.shape[2], dataa.shape[3]), dtype=np.float32)

return gdq, row_below_gdq, row_above_gdq, num_primary_crs, dummy


Expand Down
Loading