diff --git a/jetstream/analysis.py b/jetstream/analysis.py index 4788403c..add50ef2 100644 --- a/jetstream/analysis.py +++ b/jetstream/analysis.py @@ -36,12 +36,7 @@ from jetstream.metric import Metric from jetstream.platform import PLATFORM_CONFIGS from jetstream.segment import Segment -from jetstream.statistics import ( - Count, - StatisticResult, - StatisticResultCollection, - Summary, -) +from jetstream.statistics import Count, StatisticResult, StatisticResultCollection, Summary from . import bq_normalize_name diff --git a/jetstream/statistics.py b/jetstream/statistics.py index c26586f2..7e07f6e4 100644 --- a/jetstream/statistics.py +++ b/jetstream/statistics.py @@ -16,6 +16,7 @@ import mozanalysis.frequentist_stats.linear_models import mozanalysis.metrics import numpy as np +import pandas as pd from google.cloud import bigquery from metric_config_parser import metric as parser_metric from metric_config_parser.experiment import Experiment @@ -486,7 +487,9 @@ def transform( if ref_values is not None and not ref_values.empty: threshold = df[metric].dropna().quantile(threshold_quantile) - if np.issubdtype(df[metric].dtype, np.integer): + # use pandas' dtype check so nullable extension dtypes (e.g. Int64) from + # BigQuery are handled; np.issubdtype raises on pandas extension dtypes + if pd.api.types.is_integer_dtype(df[metric].dtype): threshold = int(np.ceil(threshold)) post_trim = ref_values.clip(upper=threshold) diff --git a/jetstream/tests/test_config.py b/jetstream/tests/test_config.py index a767426c..f4ab3264 100644 --- a/jetstream/tests/test_config.py +++ b/jetstream/tests/test_config.py @@ -10,11 +10,7 @@ from jetstream.config import ConfigLoader, _ConfigLoader, validate from jetstream.dryrun import DryRunFailedError -from jetstream.platform import ( - Platform, - PlatformConfigurationException, - _generate_platform_config, -) +from jetstream.platform import Platform, PlatformConfigurationException, _generate_platform_config class TestConfig: diff --git a/jetstream/tests/test_experimenter.py b/jetstream/tests/test_experimenter.py index 1d4fe9f0..2405d05b 100644 --- a/jetstream/tests/test_experimenter.py +++ b/jetstream/tests/test_experimenter.py @@ -9,12 +9,7 @@ import pytz from metric_config_parser.experiment import Branch, BucketConfig, Experiment -from jetstream.experimenter import ( - ExperimentCollection, - NimbusExperiment, - Outcome, - Segment, -) +from jetstream.experimenter import ExperimentCollection, NimbusExperiment, Outcome, Segment NIMBUS_EXPERIMENTER_FIXTURE = r""" [ diff --git a/jetstream/tests/test_exposure_signal.py b/jetstream/tests/test_exposure_signal.py index a982d84e..54cbdf82 100644 --- a/jetstream/tests/test_exposure_signal.py +++ b/jetstream/tests/test_exposure_signal.py @@ -4,9 +4,7 @@ import pytest import toml from metric_config_parser.analysis import AnalysisSpec -from metric_config_parser.exposure_signal import ( - ExposureSignal as MetricConfigParserExposureSignal, -) +from metric_config_parser.exposure_signal import ExposureSignal as MetricConfigParserExposureSignal from jetstream.config import ConfigLoader from jetstream.exposure_signal import ExposureSignal diff --git a/jetstream/tests/test_statistics.py b/jetstream/tests/test_statistics.py index a44aa844..91708720 100644 --- a/jetstream/tests/test_statistics.py +++ b/jetstream/tests/test_statistics.py @@ -85,6 +85,25 @@ def test_linear_model_mean(self): assert treatment_result.lower assert treatment_result.upper + def test_linear_model_mean_nullable_integer_dtype(self): + """Metrics read from BigQuery can have pandas nullable Int64 dtype; the integer-dtype + check in transform must not raise 'Cannot interpret Int64Dtype() as a data type'.""" + stat = LinearModelMean() + test_data = pd.DataFrame( + {"branch": ["treatment"] * 10 + ["control"] * 10, "value": list(range(20))} + ) + test_data["value"] = test_data["value"].astype("Int64") + assert pd.api.types.is_integer_dtype(test_data["value"].dtype) + + results = stat.transform( + test_data, "value", "control", None, AnalysisBasis.ENROLLMENTS, "all" + ).root + + branch_results = [r for r in results if r.comparison is None] + treatment_result = next(r for r in branch_results if r.branch == "treatment") + control_result = next(r for r in branch_results if r.branch == "control") + assert treatment_result.point < control_result.point + def test_linear_model_mean_all_zero_reference_branch(self, caplog): """When the reference branch has no non-zero values, log a clear warning.""" stat = LinearModelMean()