Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions KDEpy/BaseKDE.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numbers
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Callable, Optional, Union

import numpy as np

Expand Down Expand Up @@ -33,17 +34,17 @@ class BaseKDE(ABC):
_bw_methods = _bw_methods

@abstractmethod
def __init__(self, kernel: str, bw: float):
def __init__(self, kernel: Union[str, Callable], bw: Union[float, str, np.ndarray]):
"""Initialize the kernel density estimator.

The return type must be duplicated in the docstring to comply
with the NumPy docstring style.

Parameters
----------
kernel
Kernel function, or string matching available options.
bw
kernel : str or callable
Kernel function, or string matching available options. See cls._available_kernels.keys() for choices.
bw : float, str or array-like
The bandwidth, either a number, a string or an array-like.
"""

Expand Down Expand Up @@ -81,7 +82,7 @@ def __init__(self, kernel: str, bw: float):
assert isinstance(self.bw_method, (np.ndarray, Sequence, numbers.Number)) or callable(self.bw_method)

@abstractmethod
def fit(self, data, weights=None):
def fit(self, data: np.ndarray, weights: Optional[np.ndarray] = None) -> "BaseKDE":
"""
Fit the kernel density estimator to the data.
This method converts the data to shape (obs, dims) and the weights
Expand All @@ -95,6 +96,11 @@ def fit(self, data, weights=None):
weights : array-like, Sequence or None
May be array-like of shape (obs,), shape (obs, dims), a
Python Sequence, e.g. a list or tuple, or None.

Returns
-------
self
Returns the instance.
"""

# -------------- Set up the data depending on input -------------------
Expand Down Expand Up @@ -132,8 +138,12 @@ def fit(self, data, weights=None):
else:
self.bw = self.bw_method

return self

@abstractmethod
def evaluate(self, grid_points=None, bw_to_scalar=True):
def evaluate(
self, grid_points: Optional[Union[np.ndarray, int, tuple]] = None, bw_to_scalar: bool = True
) -> Union[np.ndarray, tuple]:
"""
Evaluate the kernel density estimator on the grid points.

Expand Down
11 changes: 6 additions & 5 deletions KDEpy/FFTKDE.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Module for the FFTKDE.
"""
import numbers
from typing import Callable, Optional, Union
import warnings

import numpy as np
Expand Down Expand Up @@ -42,8 +43,8 @@ class FFTKDE(BaseKDE):

Parameters
----------
kernel : str
The kernel function. See cls._available_kernels.keys() for choices.
kernel : str or callable
Kernel function, or string matching available options. See cls._available_kernels.keys() for choices.
bw : float or str
Bandwidth or bandwidth selection method. If a float is passed, it
is the standard deviation of the kernel. If a string it passed, it
Expand Down Expand Up @@ -72,12 +73,12 @@ class FFTKDE(BaseKDE):

"""

def __init__(self, kernel="gaussian", bw=1, norm=2):
def __init__(self, kernel: Union[str, Callable] = "gaussian", bw: Union[float, str] = 1, norm: int = 2):
self.norm = norm
super().__init__(kernel, bw)
assert isinstance(self.norm, numbers.Number) and self.norm > 0

def fit(self, data, weights=None):
def fit(self, data: np.ndarray, weights: Optional[np.ndarray] = None) -> "FFTKDE":
"""
Fit the KDE to the data. This validates the data and stores it.
Computations are performed upon evaluation on a specific grid.
Expand Down Expand Up @@ -107,7 +108,7 @@ def fit(self, data, weights=None):
super().fit(data, weights)
return self

def evaluate(self, grid_points=None):
def evaluate(self, grid_points: Optional[Union[np.ndarray, int, tuple]] = None) -> Union[np.ndarray, tuple]:
"""
Evaluate on equidistant grid points.

Expand Down
13 changes: 8 additions & 5 deletions KDEpy/NaiveKDE.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""
import itertools
import numbers
from typing import Callable, Optional, Union

import numpy as np

Expand All @@ -20,8 +21,8 @@ class NaiveKDE(BaseKDE):

Parameters
----------
kernel : str
The kernel function. See cls._available_kernels.keys() for choices.
kernel : str or callable
Kernel function, or string matching available options. See cls._available_kernels.keys() for choices.
bw : float, str or array-like
Bandwidth or bandwidth selection method. If a float is passed, it
is the standard deviation of the kernel. If a string it passed, it
Expand Down Expand Up @@ -51,11 +52,13 @@ class NaiveKDE(BaseKDE):
- Scipy implementation, at ``scipy.stats.gaussian_kde``.
"""

def __init__(self, kernel="gaussian", bw=1, norm=2):
def __init__(
self, kernel: Union[str, Callable] = "gaussian", bw: Union[float, str, np.ndarray] = 1, norm: float = 2
):
super().__init__(kernel, bw)
self.norm = norm

def fit(self, data, weights=None):
def fit(self, data: np.ndarray, weights: Optional[np.ndarray] = None) -> "NaiveKDE":
"""
Fit the KDE to the data. This validates the data and stores it.
Computations are performed when the KDE is evaluated on a grid.
Expand Down Expand Up @@ -85,7 +88,7 @@ def fit(self, data, weights=None):
super().fit(data, weights)
return self

def evaluate(self, grid_points=None):
def evaluate(self, grid_points: Optional[Union[np.ndarray, int, tuple]] = None) -> Union[np.ndarray, tuple]:
"""
Evaluate on grid points.

Expand Down
15 changes: 10 additions & 5 deletions KDEpy/TreeKDE.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Module for the TreeKDE.
"""
import numbers
from typing import Callable, Optional, Union

import numpy as np
from scipy.spatial import cKDTree
Expand All @@ -24,8 +25,8 @@ class TreeKDE(BaseKDE):

Parameters
----------
kernel : str
The kernel function. See cls._available_kernels.keys() for choices.
kernel : str or callable
Kernel function, or string matching available options. See cls._available_kernels.keys() for choices.
bw : float, str or array-like
Bandwidth or bandwidth selection method. If a float is passed, it
is the standard deviation of the kernel. If a string it passed, it
Expand Down Expand Up @@ -61,11 +62,13 @@ class TreeKDE(BaseKDE):
- Scipy implementation, at ``scipy.spatial.KDTree``.
"""

def __init__(self, kernel="gaussian", bw=1, norm=2.0):
def __init__(
self, kernel: Union[str, Callable] = "gaussian", bw: Union[float, str, np.ndarray] = 1, norm: float = 2.0
):
super().__init__(kernel, bw)
self.norm = norm

def fit(self, data, weights=None):
def fit(self, data: np.ndarray, weights: Optional[np.ndarray] = None) -> "TreeKDE":
"""
Fit the KDE to the data. This validates the data and stores it.
Computations are performed upon evaluation on a grid.
Expand Down Expand Up @@ -95,7 +98,9 @@ def fit(self, data, weights=None):
super().fit(data, weights)
return self

def evaluate(self, grid_points=None, eps=10e-4):
def evaluate(
self, grid_points: Optional[Union[np.ndarray, int, tuple]] = None, eps: float = 10e-4
) -> Union[np.ndarray, tuple]:
"""
Evaluate on grid points.

Expand Down
Loading