Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
- [x] Bernoulli
- [x] Beta
- [x] Binomial
- [ ] Dirichlet
- [x] Dirichlet
- [x] Gamma
- [x] Student's t
- [x] Uniform
Expand Down
212 changes: 210 additions & 2 deletions src/statistics/dist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
//! * Uniform
//! * Weighted Uniform
//! * Log Normal
//! * There are two enums to represent probability distribution
//! * There are three enums to represent probability distribution
//! * `OPDist<T>` : One parameter distribution (Bernoulli)
//! * `TPDist<T>` : Two parameter distribution (Uniform, Normal, Beta, Gamma)
//! * `MVDist<T>` : Multivariate distribution (Dirichlet)
//! * `T: PartialOrd + SampleUniform + Copy + Into<f64>`
//! * There are some traits for pdf
//! * `RNG` trait - extract sample & calculate pdf
//! * `RNG` trait - extract sample & calculate pdf for 1D distributions
//! * `MVRNG` trait - extract sample & calculate pdf for multivariate distributions
//! * `Statistics` trait - already shown above
//!
//! ### `RNG` trait
Expand Down Expand Up @@ -239,6 +241,55 @@
//! * Mean: $e^{\mu + \frac{\sigma^2}{2}}$
//! * Var: $(e^{\sigma^2} - 1)e^{2\mu + \sigma^2}$
//! * To generate log-normal random samples, Peroxide uses the `rand_distr::LogNormal` distribution from the `rand_distr` crate.
//! ### `MVRNG` trait
//!
//! * `MVRNG` trait is composed of four fields
//! * `sample`: Extract samples
//! * `sample_with_rng`: Extract samples with specific rng
//! * `pdf` : Calculate pdf value at specific point
//! * `ln_pdf` : Calculate log-pdf value at specific point
//! ```no_run
//! use rand::Rng;
//! pub trait MVRNG {
//! /// Extract samples of multivariate distributions
//! fn sample(&self, n: usize) -> Matrix;
//!
//! /// Extract samples of distributions with specific rng
//! fn sample_with_rng<R: Rng + Clone>(&self, rng: &mut R, n: usize) -> Matrix;
//!
//! /// Probability Density Function
//! fn pdf(&self, x: &[f64]) -> f64;
//!
//! /// Log Probability Density Function
//! fn ln_pdf(&self, x: &[f64]) -> f64;
//! }
//! ```
//!
//! ### Dirichlet Distribution
//!
//! * Definition
//! $$ \text{Dir}(\mathbf{x} | \boldsymbol{\alpha}) = \frac{1}{\text{B}(\boldsymbol{\alpha})} \prod_{i=1}^K x_i^{\alpha_i - 1} $$
//! where $\text{B}(\boldsymbol{\alpha}) = \frac{\prod_{i=1}^K \Gamma(\alpha_i)}{\Gamma(\sum_{i=1}^K \alpha_i)}$
//! * Representative value
//! * Mean: $\frac{\alpha_i}{\alpha_0}$
//! * Var : $\frac{\alpha_i(\alpha_0 - \alpha_i)}{\alpha_0^2(\alpha_0 + 1)}$
//! * To generate Dirichlet random samples, Peroxide generates $K$ independent Gamma samples and normalizes them.
//! * **Caution**: `MVDist` utilizes the existing `Statistics` trait but outputs vectors and matrices.
//!
//! ```rust
//! use peroxide::fuga::*;
//!
//! fn main() {
//! let mut rng = smallrng_from_seed(42);
//! let a = Dirichlet(vec![1.0, 2.0, 3.0]); // Dir(x | 1.0, 2.0, 3.0)
//! a.sample(100).print(); // Generate 100 samples
//! a.sample_with_rng(&mut rng, 100).print(); // Generate 100 samples with specific rng
//! a.pdf(&[0.16, 0.33, 0.51]).print(); // Probability density
//! a.mean().print(); // Mean vector
//! a.var().print(); // Variance vector
//! a.cov().print(); // Covariance matrix
//! }
//! ```

extern crate rand;
extern crate rand_distr;
Expand All @@ -255,6 +306,7 @@ use self::WeightedUniformError::*;
use crate::statistics::{ops::C, stat::Statistics};
use crate::util::non_macro::{linspace, seq};
use crate::util::useful::{auto_zip, find_interval};
use crate::structure::matrix::{matrix, Matrix, Row};
use anyhow::{bail, Result};
use std::f64::consts::E;

Expand Down Expand Up @@ -283,6 +335,15 @@ pub enum TPDist<T: PartialOrd + SampleUniform + Copy + Into<f64>> {
LogNormal(T, T),
}

/// Multivariate Distribution
///
/// # Distributions
/// * `Dirichlet(alpha)`: Dirichlet distribution
#[derive(Debug, Clone)]
pub enum MVDist<T: PartialOrd + SampleUniform + Copy + Into<f64>> {
Dirichlet(Vec<T>),
}

pub struct WeightedUniform<T: PartialOrd + SampleUniform + Copy + Into<f64>> {
weights: Vec<T>,
sum: T,
Expand Down Expand Up @@ -1000,3 +1061,150 @@ impl Statistics for WeightedUniform<f64> {
vec![1f64]
}
}

/// Multivariate Random Number Generator Trait
pub trait MVRNG {
/// Extract samples of multivariate distributions (Returns an n x k Matrix)
fn sample(&self, n: usize) -> Matrix {
let mut rng = rand::rng();
self.sample_with_rng(&mut rng, n)
}

/// Extract samples of distributions with specific rng
fn sample_with_rng<R: Rng + Clone>(&self, rng: &mut R, n: usize) -> Matrix;

/// Probability Density Function
fn pdf(&self, x: &[f64]) -> f64 {
self.ln_pdf(x).exp()
}

/// Log Probability Density Function
fn ln_pdf(&self, x: &[f64]) -> f64;
}

impl<T: PartialOrd + SampleUniform + Copy + Into<f64>> Statistics for MVDist<T> {
type Array = Matrix;
type Value = Vec<f64>;

fn mean(&self) -> Self::Value {
match self {
MVDist::Dirichlet(alpha_t) => {
let alpha: Vec<f64> = alpha_t.iter().map(|&a| a.into()).collect();
let alpha0: f64 = alpha.iter().sum();
alpha.iter().map(|&a| a / alpha0).collect()
}
}
}

fn var(&self) -> Self::Value {
match self {
MVDist::Dirichlet(alpha_t) => {
let alpha: Vec<f64> = alpha_t.iter().map(|&a| a.into()).collect();
let alpha0: f64 = alpha.iter().sum();
let norm = alpha0.powi(2) * (alpha0 + 1.0);
alpha.iter().map(|&a| a * (alpha0 - a) / norm).collect()
}
}
}

fn sd(&self) -> Self::Value {
self.var().into_iter().map(|v| v.sqrt()).collect()
}

fn cov(&self) -> Self::Array {
match self {
MVDist::Dirichlet(alpha_t) => {
let alpha: Vec<f64> = alpha_t.iter().map(|&a| a.into()).collect();
let alpha0: f64 = alpha.iter().sum();
let k = alpha.len();
let norm = alpha0.powi(2) * (alpha0 + 1.0);
let mut cov_data = vec![0f64; k * k];

for i in 0..k {
for j in 0..k {
let idx = i * k + j;
if i == j {
cov_data[idx] = alpha[i] * (alpha0 - alpha[i]) / norm;
} else {
cov_data[idx] = -alpha[i] * alpha[j] / norm;
}
}
}

matrix(cov_data, k, k, Row)
}
}
}

fn cor(&self) -> Self::Array {
let cov_matrix = self.cov();
let sd_vec = self.sd();
let k = sd_vec.len();

let mut cor_data = vec![0f64; k * k];

for i in 0..k {
for j in 0..k {
let idx = i * k + j;
cor_data[idx] = cov_matrix[(i, j)] / (sd_vec[i] * sd_vec[j]);
}
}
matrix(cor_data, k, k, Row)
}
}

impl<T: PartialOrd + SampleUniform + Copy + Into<f64>> MVRNG for MVDist<T> {
fn sample_with_rng<R: Rng + Clone>(&self, rng: &mut R, n: usize) -> Matrix {
match self {
MVDist::Dirichlet(alpha_t) => {
let alpha: Vec<f64> = alpha_t.iter().map(|&a| a.into()).collect();
let k = alpha.len();
let mut sample_data = vec![0f64; n * k];

for i in 0..n {
let mut sum = 0f64;
let mut y = vec![0f64; k];

for j in 0..k {
let gamma_dist = rand_distr::Gamma::new(alpha[j], 1.0).unwrap();
y[j] = gamma_dist.sample(rng);
sum += y[j];
}

for j in 0..k {
sample_data[i * k + j] = y[j] / sum;
}
}

matrix(sample_data, n, k, Row)
}
}
}

fn ln_pdf(&self, x: &[f64]) -> f64 {
match self {
MVDist::Dirichlet(alpha_t) => {
let alpha: Vec<f64> = alpha_t.iter().map(|&a| a.into()).collect();
assert_eq!(alpha.len(), x.len(), "Arguments must have correct dimensions.");

let mut term = 0f64;
let mut sum_x = 0f64;
let mut sum_alpha_ln_gamma = 0f64;
let mut alpha0 = 0f64;

for (&x_i, &alpha_i) in x.iter().zip(alpha.iter()) {
assert!(x_i > 0f64 && x_i < 1f64, "Arguments must be in (0, 1)");

term += (alpha_i - 1.0) * x_i.ln();
sum_alpha_ln_gamma += gamma(alpha_i).ln();
sum_x += x_i;
alpha0 += alpha_i;
}

assert!((sum_x - 1.0).abs() < 1e-4, "Arguments must sum up to 1");

term + gamma(alpha0).ln() - sum_alpha_ln_gamma
}
}
}
}
6 changes: 6 additions & 0 deletions src/util/print.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,12 @@ impl<T: Debug + PartialOrd + SampleUniform + Copy + Into<f64>> Printable for TPD
}
}

impl<T: Debug + PartialOrd + SampleUniform + Copy + Into<f64>> Printable for MVDist<T> {
fn print(&self) {
println!("{:?}", self);
}
}

//impl Printable for Number {
// fn print(&self) {
// println!("{:?}", self)
Expand Down
19 changes: 19 additions & 0 deletions tests/dist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,22 @@ fn test_binomial() {
assert!(nearly_eq(b.mean(), 80f64));
assert!(nearly_eq(b.var(), 16f64));
}

#[test]
fn test_dirichlet() {
let dir = MVDist::Dirichlet(vec![1.0, 2.0, 3.0]);
dir.sample(10).print();

let m = dir.mean();
assert!(nearly_eq(m[0], 1.0 / 6.0));
assert!(nearly_eq(m[1], 1.0 / 3.0));
assert!(nearly_eq(m[2], 0.5));

let v = dir.var();
assert!(nearly_eq(v[0], 5.0 / 252.0)); // 1 * 5 / (36 * 7)
assert!(nearly_eq(v[1], 8.0 / 252.0)); // 2 * 4 / (36 * 7)
assert!(nearly_eq(v[2], 9.0 / 252.0)); // 3 * 3 / (36 * 7)

let pdf_val = dir.pdf(&[0.33333, 0.33333, 0.33333]);
assert!(nearly_eq(pdf_val, 2.222155556222205));
}