diff --git a/jetstream/analysis.py b/jetstream/analysis.py index 985fdaaf..4788403c 100644 --- a/jetstream/analysis.py +++ b/jetstream/analysis.py @@ -831,6 +831,39 @@ def _create_subset_metric_table_query_univariate( return query + def _covariate_table_metric_name( + self, + during_metric: Metric, + covariate_metric_name: str, + covariate_period: AnalysisPeriod, + discrete_metrics: bool, + ) -> str | None: + """Resolves the ``metric`` argument for the preenrollment covariate table name. + + Discrete metric tables are partitioned by data source (see calculate_metric_for_ds), + so the covariate table is named using the data source name of the covariate metric -- + the table that actually contains the covariate metric's column. For non-discrete + analyses tables are not partitioned by data source, so no metric component is used. + """ + if not discrete_metrics: + return None + + # the covariate may be configured as a different metric than the during-experiment + # metric; look it up in the covariate period to use its own data source + if covariate_metric_name != during_metric.name: + covariate_metric = next( + ( + summary.metric + for summary in self.config.metrics.get(covariate_period, []) + if summary.metric.name == covariate_metric_name + ), + None, + ) + if covariate_metric is not None: + return covariate_metric.data_source.name + + return during_metric.data_source.name + def _create_subset_metric_table_query_covariate( self, metric_table_name: str, @@ -850,7 +883,9 @@ def _create_subset_metric_table_query_covariate( "metrics with dependencies are not currently supported for covariate adjustment" ) - metric_name = metric.name if discrete_metrics else None + metric_name = self._covariate_table_metric_name( + metric, covariate_metric_name, covariate_period, discrete_metrics + ) covariate_table_name = self._table_name( covariate_period.value, 1, analysis_basis=AnalysisBasis.ENROLLMENTS, metric=metric_name ) @@ -933,7 +968,10 @@ def _subset_metric_table_prerequisites( if covariate_params := summary.statistic.params.get("covariate_adjustment", False): # type: ignore[attr-defined] covariate_period = AnalysisPeriod(covariate_params["period"]) if covariate_period != period and period not in PREENROLLMENT_PERIODS: - metric_name = summary.metric.name if discrete_metrics else None + covariate_metric_name = covariate_params.get("metric", summary.metric.name) + metric_name = self._covariate_table_metric_name( + summary.metric, covariate_metric_name, covariate_period, discrete_metrics + ) prereqs.add( self._table_name( covariate_period.value, 1, AnalysisBasis.ENROLLMENTS, metric=metric_name diff --git a/jetstream/tests/test_analysis.py b/jetstream/tests/test_analysis.py index 5418fa91..05ecad79 100644 --- a/jetstream/tests/test_analysis.py +++ b/jetstream/tests/test_analysis.py @@ -1577,6 +1577,85 @@ def test_subset_metric_table_prerequisites_covariate(experiments): assert prereqs == {expected_covariate_table} +def test_subset_metric_table_prerequisites_covariate_discrete_uses_data_source(experiments): + """For discrete metrics, the preenrollment covariate table is partitioned by data source, + so the prerequisite table name must use the data source name, not the metric name.""" + metric = Metric( + name="metric_name", + data_source=DataSource(name="test_data_source", from_expression="test.test"), + select_expression="test", + analysis_bases=[AnalysisBasis.ENROLLMENTS], + ) + summary = MagicMock() + summary.statistic.params = { + "covariate_adjustment": {"metric": "metric_name", "period": "preenrollment_week"} + } + summary.metric = metric + + analysis = _empty_analysis(experiments) + prereqs = analysis._subset_metric_table_prerequisites( + summary, + "normandy_test_slug_enrollments_test_data_source_week_1", + AnalysisBasis.ENROLLMENTS, + AnalysisPeriod.WEEK, + True, + ) + + # table is named by data source, matching how calculate_metric_for_ds writes it + expected_covariate_table = analysis._table_name( + AnalysisPeriod.PREENROLLMENT_WEEK.value, + 1, + AnalysisBasis.ENROLLMENTS, + metric="test_data_source", + ) + assert prereqs == {expected_covariate_table} + # the (buggy) metric-name-based table should NOT be referenced + wrong_table = analysis._table_name( + AnalysisPeriod.PREENROLLMENT_WEEK.value, + 1, + AnalysisBasis.ENROLLMENTS, + metric="metric_name", + ) + assert wrong_table not in prereqs + + +def test_covariate_table_metric_name_resolves_covariate_data_source(experiments): + """When the covariate is a different metric than the during-experiment metric, the table + name must use the covariate metric's own data source.""" + from jetstream.statistics import Summary as JetstreamSummary + + during_metric = Metric( + name="during_metric", + data_source=DataSource(name="during_ds", from_expression="test.test"), + select_expression="test", + analysis_bases=[AnalysisBasis.ENROLLMENTS], + ) + covariate_metric = Metric( + name="covariate_metric", + data_source=DataSource(name="covariate_ds", from_expression="test.test"), + select_expression="test", + analysis_bases=[AnalysisBasis.ENROLLMENTS], + ) + + analysis = _empty_analysis(experiments) + analysis.config.metrics = { + AnalysisPeriod.PREENROLLMENT_WEEK: [JetstreamSummary(covariate_metric, MagicMock(), [])], + } + + resolved = analysis._covariate_table_metric_name( + during_metric, "covariate_metric", AnalysisPeriod.PREENROLLMENT_WEEK, True + ) + assert resolved == "covariate_ds" + + # non-discrete analyses are not partitioned by data source -> no metric component + assert ( + analysis._covariate_table_metric_name( + during_metric, "covariate_metric", AnalysisPeriod.PREENROLLMENT_WEEK, False + ) + is None + ) + + def test_subset_metric_table_prerequisites_covariate_skipped_for_preenrollment_period(experiments): """When the current period is a preenrollment period, covariate adjustment is not applied, so no extra prerequisite table should be returned."""