Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 11 additions & 26 deletions imap_processing/lo/l1b/lo_l1b.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,23 +636,12 @@ def set_spin_cycle_from_spin_data(
science_met_per_asc, spin_met_per_asc
)

# Add a flag for invalid ASCs
valid_mask = find_valid_asc(science_to_spin_indices, spin_data)
l1b_science["incomplete_asc"] = xr.DataArray(~valid_mask, dims=["epoch"])

# If none valid, return an empty/filtered dataset
# (preserves dims & avoids misalignment)
if not valid_mask.any():
logger.warning(
"No valid ASCs remain after filtering; returning empty epoch set"
)
return l1b_science.isel(epoch=[])

# Filter the input datasets to only the valid ASCs so all subsequent arrays align
l1a_valid = l1a_science.isel(epoch=valid_mask)
l1b_valid = l1b_science.isel(epoch=valid_mask)

# Use the valid closest indices to get the corresponding acq_start rows
science_to_spin_indices_valid = science_to_spin_indices[valid_mask]
closest_start_acq_per_asc = acq_start.isel(epoch=science_to_spin_indices_valid)
# Use the closest indices to get the corresponding acq_start rows
closest_start_acq_per_asc = acq_start.isel(epoch=science_to_spin_indices)

# compute spin start number for each remaining ASC
spin_start_num_per_asc = np.atleast_1d(get_spin_number(closest_start_acq_per_asc))
Expand All @@ -661,29 +650,29 @@ def set_spin_cycle_from_spin_data(
logical_src = l1a_science.attrs.get("Logical_source", "")
if logical_src == "imap_lo_l1a_de":
# For DE: expand per-event across ESA steps within each (valid) ASC
counts = l1a_valid["de_count"].values
counts = l1a_science["de_count"].values
spin_cycle = []
for asc_idx, _count in enumerate(counts):
esa_steps = l1a_valid["esa_step"].values[
esa_steps = l1a_science["esa_step"].values[
sum(counts[:asc_idx]) : sum(counts[: asc_idx + 1])
]
spin_cycle.extend(
spin_start_num_per_asc[asc_idx, 0] + 7 + (esa_steps - 1) * 2
)
spin_cycle = np.array(spin_cycle)
l1b_valid["spin_cycle"] = xr.DataArray(spin_cycle, dims=["epoch"])
l1b_science["spin_cycle"] = xr.DataArray(spin_cycle, dims=["epoch"])
elif logical_src == "imap_lo_l1a_histogram":
# For histogram: keep 2D array (n_valid_epochs, esa_step)
esa_steps = l1b_valid["esa_step"].values # shape: (7,)
esa_steps = l1b_science["esa_step"].values # shape: (7,)
spin_cycle = spin_start_num_per_asc + 7 + (esa_steps - 1) * 2
l1b_valid["spin_cycle"] = xr.DataArray(spin_cycle, dims=["epoch", "esa_step"])
l1b_science["spin_cycle"] = xr.DataArray(spin_cycle, dims=["epoch", "esa_step"])
else:
raise ValueError(
"set spin cycle called with unsupported dataset with "
"Logical_source: {logical_src}"
)

return l1b_valid
return l1b_science


def match_science_to_spin_asc(
Expand Down Expand Up @@ -744,13 +733,9 @@ def find_valid_asc(
valid_indices = _check_valid_indices(science_to_spin_indices)
valid_spin_count = _check_sufficient_spins(spin_data)[science_to_spin_indices]

# Combine only these two masks:
# Combine these two masks
valid_mask = valid_indices & valid_spin_count

total_invalid = (~valid_mask).sum()
if total_invalid > 0:
logger.info(f"Dropping {total_invalid} invalid ASCs total")

return valid_mask


Expand Down
16 changes: 8 additions & 8 deletions imap_processing/tests/lo/test_lo_l1b.py
Original file line number Diff line number Diff line change
Expand Up @@ -1386,18 +1386,18 @@ def test_set_spin_cycle_from_spin_data_insufficient_spins():

# Act
with patch(
"imap_processing.lo.l1b.lo_l1b.get_spin_number", return_value=np.array([28])
"imap_processing.lo.l1b.lo_l1b.get_spin_number",
return_value=np.array([28, 26, 24]),
):
result = set_spin_cycle_from_spin_data(l1a_hist, l1b_hist, spin_data)

# Assert - Only epoch 1 (science_met[1]=200) should remain
# (matched to spin with 28 spins)
assert len(result["epoch"]) == 1
expected_epochs = met_to_ttj2000ns([200])
np.testing.assert_array_equal(result["epoch"].values, expected_epochs)
assert len(result["epoch"]) == 3
np.testing.assert_array_equal(result["epoch"].values, epoch_date)

# Verify spin_cycle shape matches filtered data
assert result["spin_cycle"].shape == (1, 7)
# Verify spin_cycle shape has all valid ESA steps and all epochs
assert result["spin_cycle"].shape == (3, 7)
# We should have added a flag about an incomplete ASC
np.testing.assert_array_equal(result["incomplete_asc"], [True, False, True])


@patch(
Expand Down