From 4c2b202f04d4012babff42aba5fcb0f83498c794 Mon Sep 17 00:00:00 2001 From: Pim Meulensteen Date: Wed, 7 Jan 2026 15:56:50 +0100 Subject: [PATCH] Implement type hints for exposed classes. --- KDEpy/BaseKDE.py | 22 ++++++++++++++++------ KDEpy/FFTKDE.py | 11 ++++++----- KDEpy/NaiveKDE.py | 13 ++++++++----- KDEpy/TreeKDE.py | 15 ++++++++++----- 4 files changed, 40 insertions(+), 21 deletions(-) diff --git a/KDEpy/BaseKDE.py b/KDEpy/BaseKDE.py index 6c46ac4..e0e6bce 100644 --- a/KDEpy/BaseKDE.py +++ b/KDEpy/BaseKDE.py @@ -6,6 +6,7 @@ import numbers from abc import ABC, abstractmethod from collections.abc import Sequence +from typing import Callable, Optional, Union import numpy as np @@ -33,7 +34,7 @@ class BaseKDE(ABC): _bw_methods = _bw_methods @abstractmethod - def __init__(self, kernel: str, bw: float): + def __init__(self, kernel: Union[str, Callable], bw: Union[float, str, np.ndarray]): """Initialize the kernel density estimator. The return type must be duplicated in the docstring to comply @@ -41,9 +42,9 @@ def __init__(self, kernel: str, bw: float): Parameters ---------- - kernel - Kernel function, or string matching available options. - bw + kernel : str or callable + Kernel function, or string matching available options. See cls._available_kernels.keys() for choices. + bw : float, str or array-like The bandwidth, either a number, a string or an array-like. """ @@ -81,7 +82,7 @@ def __init__(self, kernel: str, bw: float): assert isinstance(self.bw_method, (np.ndarray, Sequence, numbers.Number)) or callable(self.bw_method) @abstractmethod - def fit(self, data, weights=None): + def fit(self, data: np.ndarray, weights: Optional[np.ndarray] = None) -> "BaseKDE": """ Fit the kernel density estimator to the data. This method converts the data to shape (obs, dims) and the weights @@ -95,6 +96,11 @@ def fit(self, data, weights=None): weights : array-like, Sequence or None May be array-like of shape (obs,), shape (obs, dims), a Python Sequence, e.g. a list or tuple, or None. + + Returns + ------- + self + Returns the instance. """ # -------------- Set up the data depending on input ------------------- @@ -132,8 +138,12 @@ def fit(self, data, weights=None): else: self.bw = self.bw_method + return self + @abstractmethod - def evaluate(self, grid_points=None, bw_to_scalar=True): + def evaluate( + self, grid_points: Optional[Union[np.ndarray, int, tuple]] = None, bw_to_scalar: bool = True + ) -> Union[np.ndarray, tuple]: """ Evaluate the kernel density estimator on the grid points. diff --git a/KDEpy/FFTKDE.py b/KDEpy/FFTKDE.py index 77c28bb..0ed848f 100644 --- a/KDEpy/FFTKDE.py +++ b/KDEpy/FFTKDE.py @@ -4,6 +4,7 @@ Module for the FFTKDE. """ import numbers +from typing import Callable, Optional, Union import warnings import numpy as np @@ -42,8 +43,8 @@ class FFTKDE(BaseKDE): Parameters ---------- - kernel : str - The kernel function. See cls._available_kernels.keys() for choices. + kernel : str or callable + Kernel function, or string matching available options. See cls._available_kernels.keys() for choices. bw : float or str Bandwidth or bandwidth selection method. If a float is passed, it is the standard deviation of the kernel. If a string it passed, it @@ -72,12 +73,12 @@ class FFTKDE(BaseKDE): """ - def __init__(self, kernel="gaussian", bw=1, norm=2): + def __init__(self, kernel: Union[str, Callable] = "gaussian", bw: Union[float, str] = 1, norm: int = 2): self.norm = norm super().__init__(kernel, bw) assert isinstance(self.norm, numbers.Number) and self.norm > 0 - def fit(self, data, weights=None): + def fit(self, data: np.ndarray, weights: Optional[np.ndarray] = None) -> "FFTKDE": """ Fit the KDE to the data. This validates the data and stores it. Computations are performed upon evaluation on a specific grid. @@ -107,7 +108,7 @@ def fit(self, data, weights=None): super().fit(data, weights) return self - def evaluate(self, grid_points=None): + def evaluate(self, grid_points: Optional[Union[np.ndarray, int, tuple]] = None) -> Union[np.ndarray, tuple]: """ Evaluate on equidistant grid points. diff --git a/KDEpy/NaiveKDE.py b/KDEpy/NaiveKDE.py index eb51d9d..33c1405 100644 --- a/KDEpy/NaiveKDE.py +++ b/KDEpy/NaiveKDE.py @@ -5,6 +5,7 @@ """ import itertools import numbers +from typing import Callable, Optional, Union import numpy as np @@ -20,8 +21,8 @@ class NaiveKDE(BaseKDE): Parameters ---------- - kernel : str - The kernel function. See cls._available_kernels.keys() for choices. + kernel : str or callable + Kernel function, or string matching available options. See cls._available_kernels.keys() for choices. bw : float, str or array-like Bandwidth or bandwidth selection method. If a float is passed, it is the standard deviation of the kernel. If a string it passed, it @@ -51,11 +52,13 @@ class NaiveKDE(BaseKDE): - Scipy implementation, at ``scipy.stats.gaussian_kde``. """ - def __init__(self, kernel="gaussian", bw=1, norm=2): + def __init__( + self, kernel: Union[str, Callable] = "gaussian", bw: Union[float, str, np.ndarray] = 1, norm: float = 2 + ): super().__init__(kernel, bw) self.norm = norm - def fit(self, data, weights=None): + def fit(self, data: np.ndarray, weights: Optional[np.ndarray] = None) -> "NaiveKDE": """ Fit the KDE to the data. This validates the data and stores it. Computations are performed when the KDE is evaluated on a grid. @@ -85,7 +88,7 @@ def fit(self, data, weights=None): super().fit(data, weights) return self - def evaluate(self, grid_points=None): + def evaluate(self, grid_points: Optional[Union[np.ndarray, int, tuple]] = None) -> Union[np.ndarray, tuple]: """ Evaluate on grid points. diff --git a/KDEpy/TreeKDE.py b/KDEpy/TreeKDE.py index 197c231..afcea92 100644 --- a/KDEpy/TreeKDE.py +++ b/KDEpy/TreeKDE.py @@ -4,6 +4,7 @@ Module for the TreeKDE. """ import numbers +from typing import Callable, Optional, Union import numpy as np from scipy.spatial import cKDTree @@ -24,8 +25,8 @@ class TreeKDE(BaseKDE): Parameters ---------- - kernel : str - The kernel function. See cls._available_kernels.keys() for choices. + kernel : str or callable + Kernel function, or string matching available options. See cls._available_kernels.keys() for choices. bw : float, str or array-like Bandwidth or bandwidth selection method. If a float is passed, it is the standard deviation of the kernel. If a string it passed, it @@ -61,11 +62,13 @@ class TreeKDE(BaseKDE): - Scipy implementation, at ``scipy.spatial.KDTree``. """ - def __init__(self, kernel="gaussian", bw=1, norm=2.0): + def __init__( + self, kernel: Union[str, Callable] = "gaussian", bw: Union[float, str, np.ndarray] = 1, norm: float = 2.0 + ): super().__init__(kernel, bw) self.norm = norm - def fit(self, data, weights=None): + def fit(self, data: np.ndarray, weights: Optional[np.ndarray] = None) -> "TreeKDE": """ Fit the KDE to the data. This validates the data and stores it. Computations are performed upon evaluation on a grid. @@ -95,7 +98,9 @@ def fit(self, data, weights=None): super().fit(data, weights) return self - def evaluate(self, grid_points=None, eps=10e-4): + def evaluate( + self, grid_points: Optional[Union[np.ndarray, int, tuple]] = None, eps: float = 10e-4 + ) -> Union[np.ndarray, tuple]: """ Evaluate on grid points.