Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 81 additions & 4 deletions KDEpy/BaseKDE.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from KDEpy.kernel_funcs import _kernel_functions
from KDEpy.bw_selection import _bw_methods
from KDEpy.bw_selection import _bw_methods, cross_val
from KDEpy.utils import autogrid


Expand All @@ -31,7 +31,7 @@ class BaseKDE(ABC):
_bw_methods = _bw_methods

@abstractmethod
def __init__(self, kernel: str, bw: float):
def __init__(self, kernel: str, bw: float, norm: float):
"""Initialize the kernel density estimator.

The return type must be duplicated in the docstring to comply
Expand Down Expand Up @@ -59,6 +59,9 @@ def __init__(self, kernel: str, bw: float):
else:
raise ValueError(msg)

# CV method must be added here since it depends on self
_bw_methods["CV"] = self.cross_val

# The `bw` parameter may either be a positive number, a string, or
# array-like such that each point in the data has a unique bw
if isinstance(bw, numbers.Number) and bw > 0:
Expand All @@ -74,12 +77,15 @@ def __init__(self, kernel: str, bw: float):
else:
raise ValueError("Bandwidth must be > 0, array-like or a string.")

self.norm = norm

# Test quickly that the method has done what it was supposed to do
assert callable(self.kernel)
assert isinstance(self.bw_method, (np.ndarray, Sequence, numbers.Number)) or callable(self.bw_method)
assert isinstance(self.norm, numbers.Number) and self.norm > 0

@abstractmethod
def fit(self, data, weights=None):
def fit(self, data, weights=None, **kwargs):
"""
Fit the kernel density estimator to the data.
This method converts the data to shape (obs, dims) and the weights
Expand All @@ -93,6 +99,8 @@ def fit(self, data, weights=None):
weights : array-like, Sequence or None
May be array-like of shape (obs,), shape (obs, dims), a
Python Sequence, e.g. a list or tuple, or None.
**kwargs:
Keyword arguments to be passed to the bandwidth optimization method.
"""

# -------------- Set up the data depending on input -------------------
Expand Down Expand Up @@ -126,7 +134,7 @@ def fit(self, data, weights=None):
if isinstance(self.bw_method, (np.ndarray, Sequence)):
self.bw = self.bw_method
elif callable(self.bw_method):
self.bw = self.bw_method(self.data, self.weights)
self.bw = self.bw_method(self.data, self.weights, **kwargs)
else:
self.bw = self.bw_method

Expand Down Expand Up @@ -175,6 +183,75 @@ def evaluate(self, grid_points=None, bw_to_scalar=True):
assert bw > 0
assert len(self.grid_points.shape) == 2

def score(self, test_data, test_weights=None):
    """
    Return the mean log-density of test data under the fitted model.

    The test points are evaluated with ``self.evaluate`` and the natural
    logarithm of the result is averaged with ``np.mean``. When test
    weights are given, each log-density is multiplied by its weight
    before the average is taken.

    Parameters
    ----------
    test_data : array-like or Sequence
        May be array-like of shape (obs,), shape (obs, dims) or a
        Python Sequence, e.g. a list or tuple.
    test_weights : array-like, Sequence or None
        One weight per test observation. The processed weights are
        flattened, and their length must equal the number of test
        observations. If None, an unweighted mean is returned.

    Returns
    -------
    score : float
        ``np.mean`` of the (optionally weighted) log-densities.

    Raises
    ------
    ValueError
        If the test data holds no observations, or if the number of
        weights differs from the number of test observations.
    """
    # Bring the test points to a 2D array so observations and dimensions
    # can be unpacked from the shape
    points = self._process_sequence(test_data)
    num_obs, num_dims = points.shape

    if num_obs <= 0:
        raise ValueError("Test data must contain at least one data point.")
    assert num_dims > 0

    # Unweighted case: plain average of the log-densities
    if test_weights is None:
        return np.mean(np.log(self.evaluate(points)))

    # Weighted case: validate the weights before evaluating the model
    flat_weights = self._process_sequence(test_weights).ravel()
    if len(flat_weights) != num_obs:
        raise ValueError("Number of test data obs must match test weights")

    # NOTE(review): np.mean divides by the number of observations, not by
    # the sum of the weights, so the result scales with the weights'
    # overall magnitude -- confirm this is the intended definition.
    log_density = np.log(self.evaluate(points))
    return np.mean(flat_weights * log_density)

def cross_val(self, data, weights=None, cv=10, seed=None, grid=None):
    """
    Select a bandwidth by maximizing the cross validated score.

    The score is computed over a grid of candidate bandwidths and the
    best-scoring bandwidth is returned. The approach is robust against
    multimodal distributions and also works with variable bandwidths
    (e.g. by passing the output of a k nearest neighbors algorithm as
    the "seed" parameter).

    Habbema, J. D. F., Hermans, J., and Van den Broek, K. (1974) A stepwise
    discrimination analysis program using density estimation.

    Leave-one-out MLCV method in R: https://rdrr.io/cran/kedd/man/h.mlcv.html

    Parameters
    ----------
    data : array-like
        The data points. Data must have shape (obs, dims).
    weights : array-like
        One weight per data point. Numbers of observations must match
        the data points.
    cv : int
        The number of cross validation folds. Setting cv equal to obs
        gives leave-one-out cross validation.
    seed : float or array-like
        The seed bandwidth. Defaults to a simplified version of the
        silverman rule.
    grid : array-like
        The grid of factors applied to the seed bandwidth:
        bw_grid[i] = bw * grid[i]
        Defaults to np.logspace(-1,1,20).
    """
    # Inside a method body the bare name `cross_val` refers to the
    # module-level function imported from KDEpy.bw_selection, not to this
    # method, so this call delegates rather than recurses.
    options = {"weights": weights, "cv": cv, "seed": seed, "grid": grid}
    return cross_val(self, data, **options)

@staticmethod
def _process_sequence(sequence_array_like):
"""
Expand Down
10 changes: 5 additions & 5 deletions KDEpy/FFTKDE.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,9 @@ class FFTKDE(BaseKDE):
"""

def __init__(self, kernel="gaussian", bw=1, norm=2):
self.norm = norm
super().__init__(kernel, bw)
assert isinstance(self.norm, numbers.Number) and self.norm > 0
super().__init__(kernel, bw, norm)

def fit(self, data, weights=None):
def fit(self, data, weights=None, **kwargs):
"""
Fit the KDE to the data. This validates the data and stores it.
Computations are performed upon evaluation on a specific grid.
Expand All @@ -83,6 +81,8 @@ def fit(self, data, weights=None):
The data points.
weights: array-like
One weight per data point. Must have same shape as the data.
**kwargs:
Keyword arguments to be passed to the bandwidth optimization method.

Returns
-------
Expand All @@ -99,7 +99,7 @@ def fit(self, data, weights=None):
"""

# Sets self.data
super().fit(data, weights)
super().fit(data, weights, **kwargs)
return self

def evaluate(self, grid_points=None):
Expand Down
11 changes: 6 additions & 5 deletions KDEpy/NaiveKDE.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class NaiveKDE(BaseKDE):
bw : float, str or array-like
Bandwidth or bandwidth selection method. If a float is passed, it
is the standard deviation of the kernel. If a string is passed, it
is the bandwidth selection method, see cls._bw_methods.keys() for
is the bandwidth selection method, see cls()._bw_methods.keys() for
choices. If an array-like is passed, it is the bandwidth of each
point.
norm : float
Expand All @@ -50,10 +50,9 @@ class NaiveKDE(BaseKDE):
"""

def __init__(self, kernel="gaussian", bw=1, norm=2):
super().__init__(kernel, bw)
self.norm = norm
super().__init__(kernel, bw, norm)

def fit(self, data, weights=None):
def fit(self, data, weights=None, **kwargs):
"""
Fit the KDE to the data. This validates the data and stores it.
Computations are performed when the KDE is evaluated on a grid.
Expand All @@ -65,6 +64,8 @@ def fit(self, data, weights=None):
weights: array-like
One weight per data point. Must have shape (obs,). If None is
passed, uniform weights are used.
**kwargs:
Keyword arguments to be passed to the bandwidth optimization method.

Returns
-------
Expand All @@ -80,7 +81,7 @@ def fit(self, data, weights=None):
>>> x, y = kde()
"""
# Sets self.data
super().fit(data, weights)
super().fit(data, weights, **kwargs)
return self

def evaluate(self, grid_points=None):
Expand Down
11 changes: 6 additions & 5 deletions KDEpy/TreeKDE.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class TreeKDE(BaseKDE):
bw : float, str or array-like
Bandwidth or bandwidth selection method. If a float is passed, it
is the standard deviation of the kernel. If a string is passed, it
is the bandwidth selection method, see cls._bw_methods.keys() for
is the bandwidth selection method, see cls()._bw_methods.keys() for
choices. If an array-like is passed, it is the bandwidth of each
point.
norm : float
Expand Down Expand Up @@ -60,10 +60,9 @@ class TreeKDE(BaseKDE):
"""

def __init__(self, kernel="gaussian", bw=1, norm=2.0):
super().__init__(kernel, bw)
self.norm = norm
super().__init__(kernel, bw, norm)

def fit(self, data, weights=None):
def fit(self, data, weights=None, **kwargs):
"""
Fit the KDE to the data. This validates the data and stores it.
Computations are performed upon evaluation on a grid.
Expand All @@ -75,6 +74,8 @@ def fit(self, data, weights=None):
weights: array-like
One weight per data point. Numbers of observations must match
the data points.
**kwargs:
Keyword arguments to be passed to the bandwidth optimization method.

Returns
-------
Expand All @@ -90,7 +91,7 @@ def fit(self, data, weights=None):
>>> x, y = kde()
"""
# Sets self.data
super().fit(data, weights)
super().fit(data, weights, **kwargs)
return self

def evaluate(self, grid_points=None, eps=10e-4):
Expand Down
5 changes: 1 addition & 4 deletions KDEpy/binning.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,10 +268,7 @@ def linbin_Ndim_python(data, grid_points, weights=None):
# Compute integer part and fractional part for every x_i
# Compute relation to previous grid point, and next grid point
int_frac = (
(
(int(coordinate), 1 - (coordinate % 1)),
(int(coordinate) + 1, (coordinate % 1)),
)
((int(coordinate), 1 - (coordinate % 1)), (int(coordinate) + 1, (coordinate % 1)))
for coordinate in observation
)

Expand Down
Loading