Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions jetstream/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,39 @@ def _create_subset_metric_table_query_univariate(

return query

def _covariate_table_metric_name(
    self,
    during_metric: Metric,
    covariate_metric_name: str,
    covariate_period: AnalysisPeriod,
    discrete_metrics: bool,
) -> str | None:
    """Pick the ``metric`` component used when naming the covariate table.

    Per the original author's note, discrete metric tables are partitioned by
    data source (see calculate_metric_for_ds), so the covariate table is named
    after the data source of the covariate metric rather than after a metric
    name. Non-discrete analyses use no metric component at all.

    Args:
        during_metric: Metric analysed during the experiment; its data source
            is the fallback when no separate covariate metric is resolved.
        covariate_metric_name: Name of the metric configured as the covariate.
        covariate_period: Period whose configured summaries are searched for
            the covariate metric.
        discrete_metrics: Whether the analysis runs in discrete-metrics mode.

    Returns:
        ``None`` when ``discrete_metrics`` is false; otherwise the data source
        name of the first summary metric in ``covariate_period`` whose name
        equals ``covariate_metric_name``, or of ``during_metric`` when the
        names already match or no such summary is configured.
    """
    if not discrete_metrics:
        return None

    # Start from the during-experiment metric and swap in the covariate's own
    # metric only when it is a different metric that is configured for the
    # covariate period.
    table_metric = during_metric
    if covariate_metric_name != during_metric.name:
        period_summaries = self.config.metrics.get(covariate_period, [])
        for period_summary in period_summaries:
            if period_summary.metric.name == covariate_metric_name:
                table_metric = period_summary.metric
                break

    return table_metric.data_source.name

def _create_subset_metric_table_query_covariate(
self,
metric_table_name: str,
Expand All @@ -850,7 +883,9 @@ def _create_subset_metric_table_query_covariate(
"metrics with dependencies are not currently supported for covariate adjustment"
)

metric_name = metric.name if discrete_metrics else None
metric_name = self._covariate_table_metric_name(
metric, covariate_metric_name, covariate_period, discrete_metrics
)
covariate_table_name = self._table_name(
covariate_period.value, 1, analysis_basis=AnalysisBasis.ENROLLMENTS, metric=metric_name
)
Expand Down Expand Up @@ -933,7 +968,10 @@ def _subset_metric_table_prerequisites(
if covariate_params := summary.statistic.params.get("covariate_adjustment", False): # type: ignore[attr-defined]
covariate_period = AnalysisPeriod(covariate_params["period"])
if covariate_period != period and period not in PREENROLLMENT_PERIODS:
metric_name = summary.metric.name if discrete_metrics else None
covariate_metric_name = covariate_params.get("metric", summary.metric.name)
metric_name = self._covariate_table_metric_name(
summary.metric, covariate_metric_name, covariate_period, discrete_metrics
)
prereqs.add(
self._table_name(
covariate_period.value, 1, AnalysisBasis.ENROLLMENTS, metric=metric_name
Expand Down
79 changes: 79 additions & 0 deletions jetstream/tests/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -1577,6 +1577,85 @@ def test_subset_metric_table_prerequisites_covariate(experiments):
assert prereqs == {expected_covariate_table}


def test_subset_metric_table_prerequisites_covariate_discrete_uses_data_source(experiments):
    """Discrete-metric covariate prerequisites are named by data source.

    With discrete metrics enabled, the prerequisite preenrollment table must be
    named with the metric's data source name rather than the metric name.
    """
    data_source = DataSource(name="test_data_source", from_expression="test.test")
    metric = Metric(
        name="metric_name",
        data_source=data_source,
        select_expression="test",
        analysis_bases=[AnalysisBasis.ENROLLMENTS],
    )

    summary = MagicMock()
    covariate_params = {"metric": "metric_name", "period": "preenrollment_week"}
    summary.statistic.params = {"covariate_adjustment": covariate_params}
    summary.metric = metric

    analysis = _empty_analysis(experiments)
    prereqs = analysis._subset_metric_table_prerequisites(
        summary,
        "normandy_test_slug_enrollments_test_data_source_week_1",
        AnalysisBasis.ENROLLMENTS,
        AnalysisPeriod.WEEK,
        True,
    )

    preenrollment_period = AnalysisPeriod.PREENROLLMENT_WEEK.value

    # The only prerequisite is the table keyed by the data source name
    # (the naming the original comment attributes to calculate_metric_for_ds).
    table_by_data_source = analysis._table_name(
        preenrollment_period,
        1,
        AnalysisBasis.ENROLLMENTS,
        metric="test_data_source",
    )
    assert prereqs == {table_by_data_source}

    # A table keyed by the metric name must never be listed as a prerequisite.
    table_by_metric_name = analysis._table_name(
        preenrollment_period,
        1,
        AnalysisBasis.ENROLLMENTS,
        metric="metric_name",
    )
    assert table_by_metric_name not in prereqs


def test_covariate_table_metric_name_resolves_covariate_data_source(experiments):
    """A covariate that is a distinct metric resolves to its own data source.

    Also checks that a non-discrete analysis yields no metric component.
    """
    from jetstream.statistics import Summary as JetstreamSummary

    def build_metric(metric_name, data_source_name):
        # Both metrics differ only in their own name and data source name.
        return Metric(
            name=metric_name,
            data_source=DataSource(name=data_source_name, from_expression="test.test"),
            select_expression="test",
            analysis_bases=[AnalysisBasis.ENROLLMENTS],
        )

    during_metric = build_metric("during_metric", "during_ds")
    covariate_metric = build_metric("covariate_metric", "covariate_ds")

    analysis = _empty_analysis(experiments)
    covariate_summary = JetstreamSummary(covariate_metric, MagicMock(), [])
    analysis.config.metrics = {AnalysisPeriod.PREENROLLMENT_WEEK: [covariate_summary]}

    # Discrete analysis: the covariate metric's own data source is used.
    discrete_result = analysis._covariate_table_metric_name(
        during_metric, "covariate_metric", AnalysisPeriod.PREENROLLMENT_WEEK, True
    )
    assert discrete_result == "covariate_ds"

    # Non-discrete analysis: no metric component is returned.
    non_discrete_result = analysis._covariate_table_metric_name(
        during_metric, "covariate_metric", AnalysisPeriod.PREENROLLMENT_WEEK, False
    )
    assert non_discrete_result is None


def test_subset_metric_table_prerequisites_covariate_skipped_for_preenrollment_period(experiments):
"""When the current period is a preenrollment period, covariate adjustment is not applied,
so no extra prerequisite table should be returned."""
Expand Down
Loading