Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 1 addition & 6 deletions jetstream/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,7 @@
from jetstream.metric import Metric
from jetstream.platform import PLATFORM_CONFIGS
from jetstream.segment import Segment
from jetstream.statistics import (
Count,
StatisticResult,
StatisticResultCollection,
Summary,
)
from jetstream.statistics import Count, StatisticResult, StatisticResultCollection, Summary

from . import bq_normalize_name

Expand Down
5 changes: 4 additions & 1 deletion jetstream/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import mozanalysis.frequentist_stats.linear_models
import mozanalysis.metrics
import numpy as np
import pandas as pd
from google.cloud import bigquery
from metric_config_parser import metric as parser_metric
from metric_config_parser.experiment import Experiment
Expand Down Expand Up @@ -486,7 +487,9 @@ def transform(

if ref_values is not None and not ref_values.empty:
threshold = df[metric].dropna().quantile(threshold_quantile)
if np.issubdtype(df[metric].dtype, np.integer):
# use pandas' dtype check so nullable extension dtypes (e.g. Int64) from
# BigQuery are handled; np.issubdtype raises on pandas extension dtypes
if pd.api.types.is_integer_dtype(df[metric].dtype):
threshold = int(np.ceil(threshold))
post_trim = ref_values.clip(upper=threshold)

Expand Down
6 changes: 1 addition & 5 deletions jetstream/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,7 @@

from jetstream.config import ConfigLoader, _ConfigLoader, validate
from jetstream.dryrun import DryRunFailedError
from jetstream.platform import (
Platform,
PlatformConfigurationException,
_generate_platform_config,
)
from jetstream.platform import Platform, PlatformConfigurationException, _generate_platform_config


class TestConfig:
Expand Down
7 changes: 1 addition & 6 deletions jetstream/tests/test_experimenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,7 @@
import pytz
from metric_config_parser.experiment import Branch, BucketConfig, Experiment

from jetstream.experimenter import (
ExperimentCollection,
NimbusExperiment,
Outcome,
Segment,
)
from jetstream.experimenter import ExperimentCollection, NimbusExperiment, Outcome, Segment

NIMBUS_EXPERIMENTER_FIXTURE = r"""
[
Expand Down
4 changes: 1 addition & 3 deletions jetstream/tests/test_exposure_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
import pytest
import toml
from metric_config_parser.analysis import AnalysisSpec
from metric_config_parser.exposure_signal import (
ExposureSignal as MetricConfigParserExposureSignal,
)
from metric_config_parser.exposure_signal import ExposureSignal as MetricConfigParserExposureSignal

from jetstream.config import ConfigLoader
from jetstream.exposure_signal import ExposureSignal
Expand Down
19 changes: 19 additions & 0 deletions jetstream/tests/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,25 @@ def test_linear_model_mean(self):
assert treatment_result.lower
assert treatment_result.upper

def test_linear_model_mean_nullable_integer_dtype(self):
    """LinearModelMean.transform must accept a metric column of nullable ``Int64`` dtype.

    Metrics read from BigQuery can have pandas nullable Int64 dtype; the integer-dtype
    check in transform must not raise 'Cannot interpret Int64Dtype() as a data type'.
    """
    # Treatment rows carry the values 0-9 and control rows 10-19, so the treatment
    # estimate is expected to come out below the control estimate.
    branch_labels = ["treatment"] * 10 + ["control"] * 10
    frame = pd.DataFrame({"branch": branch_labels, "value": list(range(20))})
    frame = frame.astype({"value": "Int64"})

    # Guard the precondition: the column really is a pandas integer extension dtype.
    assert pd.api.types.is_integer_dtype(frame["value"].dtype)

    model = LinearModelMean()
    outcome = model.transform(
        frame, "value", "control", None, AnalysisBasis.ENROLLMENTS, "all"
    )

    # Only per-branch rows (those without a comparison) are inspected.
    per_branch = []
    for row in outcome.root:
        if row.comparison is None:
            per_branch.append(row)

    treatment_row = next(row for row in per_branch if row.branch == "treatment")
    control_row = next(row for row in per_branch if row.branch == "control")
    assert treatment_row.point < control_row.point

def test_linear_model_mean_all_zero_reference_branch(self, caplog):
"""When the reference branch has no non-zero values, log a clear warning."""
stat = LinearModelMean()
Expand Down
Loading