From 2199a5c17387be6c72a6c5ed81b4f92cfbee6583 Mon Sep 17 00:00:00 2001 From: tmcclintock Date: Fri, 9 May 2025 23:38:59 -0700 Subject: [PATCH 1/3] Fix: type checking --- src/ConditionalGMM/UnivariateGMM.py | 11 +++++------ src/ConditionalGMM/condGMM.py | 4 ++-- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/src/ConditionalGMM/UnivariateGMM.py b/src/ConditionalGMM/UnivariateGMM.py index 52fb14e..ddc4cc3 100644 --- a/src/ConditionalGMM/UnivariateGMM.py +++ b/src/ConditionalGMM/UnivariateGMM.py @@ -1,7 +1,8 @@ """Helpful functions that can only be computed (easily) for univarate GMMs.""" import numpy as np -import scipy as sp +from scipy.optimize import newton +from scipy.stats import norm class UniGMM: @@ -49,7 +50,7 @@ def pdf(self, x): assert np.ndim(x) < 2 # TODO vectorize pdfs = np.array( - [sp.stats.norm.pdf(x, mi, vi) for mi, vi in zip(self.means, self.variances)] + [norm.pdf(x, mi, vi) for mi, vi in zip(self.means, self.variances)] ) return np.dot(self.weights, pdfs) @@ -76,7 +77,7 @@ def cdf(self, x): """ cdfs = np.array( - [sp.stats.norm.cdf(x, mi, vi) for mi, vi in zip(self.means, self.variances)] + [norm.cdf(x, mi, vi) for mi, vi in zip(self.means, self.variances)] ) return np.dot(self.weights, cdfs) @@ -103,9 +104,7 @@ def ppf(self, q): (float or array-like) quantile corresponding to `q` """ - return sp.optimize.newton( - func=lambda x: self.cdf(x) - q, x0=self.mean(), fprime=self.pdf - ) + return newton(func=lambda x: self.cdf(x) - q, x0=self.mean(), fprime=self.pdf) def mean(self): """Mean of the RV for the GMM. diff --git a/src/ConditionalGMM/condGMM.py b/src/ConditionalGMM/condGMM.py index f6b384b..66d03b1 100644 --- a/src/ConditionalGMM/condGMM.py +++ b/src/ConditionalGMM/condGMM.py @@ -1,7 +1,7 @@ """Conditional Gaussian mixture model.""" import numpy as np -import scipy as sp +from scipy.stats import multivariate_normal from ConditionalGMM.MNorm import CondMNorm @@ -103,7 +103,7 @@ def unconditional_pdf_x2(self, x2=None, component_probs=False): probs = w * np.array( [ - sp.stats.multivariate_normal.pdf(x2, mean=mus[i], cov=covs[i]) + multivariate_normal.pdf(x2, mean=mus[i], cov=covs[i]) for i in range(len(w)) ] ) From fc631e5ccd5c08ac32c493caacb0851f563c4c58 Mon Sep 17 00:00:00 2001 From: tmcclintock Date: Fri, 9 May 2025 23:41:08 -0700 Subject: [PATCH 2/3] fix: type check on test file --- tests/test_UGMM.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_UGMM.py b/tests/test_UGMM.py index d8267c8..f8f2780 100644 --- a/tests/test_UGMM.py +++ b/tests/test_UGMM.py @@ -2,7 +2,7 @@ import numpy as np import numpy.testing as npt -import scipy as sp +from scipy.stats import norm from ConditionalGMM import UnivariateGMM @@ -25,7 +25,7 @@ def test_uggm_pdf(): pdf = ugmm.pdf(x) truepdf = np.dot( np.array(weights), - np.array([sp.stats.norm.pdf(x, mi, vi) for mi, vi in zip(means, vars)]), + np.array([norm.pdf(x, mi, vi) for mi, vi in zip(means, vars)]), ) npt.assert_equal(pdf, truepdf) From 95f1d4d3e48d4512482ab08317b9a5cb340b7743 Mon Sep 17 00:00:00 2001 From: tmcclintock Date: Fri, 9 May 2025 23:42:15 -0700 Subject: [PATCH 3/3] feat: type check with ty in ci --- .github/workflows/ci.yaml | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index c590769..b879483 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -24,7 +24,14 @@ jobs: - name: Install the project and dependencies run: uv sync --all-extras - - uses: astral-sh/ruff-action@v3 + - name: Lint + uses: astral-sh/ruff-action@v3 + + - name: Type checkg + run: | + uv tool install ty + ty check src + ty check tests - name: Run tests with coverage run: uv run coverage run -m pytest