Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 77 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,84 @@ To install Dolphin, first clone the repository. Then use pip:
pip install -e .
```

## Tutorial

For a complete walkthrough, see **[tutorial.ipynb](tutorial.ipynb)** or **[tutorial.py](tutorial.py)**, which train a model to compute `(a + b)² // c` from MNIST digit images.

### Quick Start

```python
import dolphin
from dolphin import Distribution
from dolphin.provenances import get_provenance

# Configure provenance
Distribution.provenance = get_provenance("damp")
```

**1. Wrap neural networks with `@dolphin.distribution`** — outputs become Distributions over symbols:

```python
@dolphin.distribution(range(10)) # Symbols are digits 0-9
class DigitNet(nn.Module):
def forward(self, x):
# ... your network ...
return F.softmax(logits, dim=1) # Automatically wrapped
```

**2. Define symbolic operations with `@dolphin.function`** — write plain Python, Dolphin handles the probabilities:

```python
@dolphin.function
def formula(a, b, c):
return (a + b) ** 2 // c
```

**3. Combine them in a model:**

```python
class FormulaModel(nn.Module):
def __init__(self):
super().__init__()
self.digit_net = DigitNet()

def forward(self, images):
a = self.digit_net(images[0])
b = self.digit_net(images[1])
c = self.digit_net(images[2])

result = formula(a, b, c, if_=lambda a, b, c: c != 0) # skip division by zero
return result.get_probabilities()
```

**4. Train with standard PyTorch** — gradients flow through symbolic ops:

```python
output = model(images)
loss = F.cross_entropy(output, targets)
loss.backward() # end-to-end differentiable!
optimizer.step()
```

### Other Features

```python
# Conditional computation
result = formula(a, b, c, if_=lambda a, b, c: c != 0)

# Map output to fixed size
result.map_symbols(list(range(num_classes))).get_probabilities()

# Direct Distribution operations
d1 + d2 # add
d1 * d2 # multiply
d1.map(lambda x: x * 2) # transform symbols
d1.filter(lambda x: x > 0) # filter symbols
```

## Running Experiments

To run the experiments, you must first download the data. You can get it from the following drive link:
To run the experiments, you must first download the data. You can get it from the following drive link: https://drive.google.com/file/d/1cP1W9OluX_lOUWn6jZ9QpEb5qtlopXCP/view?usp=sharing


### MNIST Sum-N
Expand Down Expand Up @@ -42,4 +117,4 @@ python run.py --cuda --n-epochs=10 --seed 1831 --learning-rate 1e-5
```bash
cd experiments/mugen
python run.py --phase=train --train_size=1000 --provenance=damp --seed=1234 --epochs=100 --use_cuda
```
```
108 changes: 108 additions & 0 deletions dolphin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,111 @@ def wrapper(*args, **kwargs):
else:
return d.apply(args, f)
return wrapper


def function(f):
    """
    A decorator that lifts a function operating on symbols to work with Distributions.

    The returned wrapper keeps the decorated function's name and docstring
    (via functools.wraps).

    Usage:
        @dolphin.function
        def add_and_square(a, b):
            return (a + b) ** 2

        # Works with regular values
        result = add_and_square(2, 3)  # Returns 25

        # Works with Distributions - applies function to all symbol combinations
        d1 = Distribution(probs1, [1, 2, 3])
        d2 = Distribution(probs2, [4, 5])
        result = add_and_square(d1, d2)  # Returns Distribution over [25, 36, 36, 49, 49, 64]

        # Mix of Distributions and regular values
        result = add_and_square(d1, 10)  # Adds 10 to each symbol in d1, then squares

        # Conditional application with 'if_' keyword
        @dolphin.function
        def divide(a, b):
            return a / b

        result = divide(d1, d2, if_=lambda a, b: b != 0)  # Only where b != 0

    Args:
        f: The symbol-level function to lift.

    Returns:
        A wrapper accepting the same positional arguments as ``f`` plus an
        optional ``if_`` keyword. If no positional argument is a Distribution,
        the wrapper simply returns ``f(*args, **kwargs)``; otherwise it returns
        the result of ``map`` / ``apply`` / ``apply_if`` on the first argument.
    """
    # Local import keeps this block self-contained regardless of the module header.
    import functools

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        # 'if_' is consumed here and never forwarded to f.
        condition = kwargs.pop('if_', None)

        # If no positional args are Distributions, call f normally.
        # Only positional arguments are inspected; keyword arguments are
        # forwarded to f untouched.
        if not any(isinstance(arg, Distribution) for arg in args):
            return f(*args, **kwargs)

        # Reference Distribution used when lifting plain values.
        ref_dist = next(arg for arg in args if isinstance(arg, Distribution))

        # The first argument drives the computation, so it must be a Distribution.
        first = args[0]
        if not isinstance(first, Distribution):
            first = Distribution.from_value(first, ref_dist)

        # Bind any remaining keyword arguments so only symbols are passed positionally.
        func = (lambda *a: f(*a, **kwargs)) if kwargs else f

        if len(args) == 1:
            # NOTE(review): 'if_' is not applied on this single-argument path;
            # confirm whether a unary condition should be supported.
            return first.map(func)

        if len(args) == 2:
            # Single remaining arg is passed through as-is; the Distribution
            # methods are expected to lift it if needed.
            others = args[1]
        else:
            # Multiple remaining args - lift non-Distributions to Distributions.
            others = [
                arg if isinstance(arg, Distribution) else Distribution.from_value(arg, ref_dist)
                for arg in args[1:]
            ]

        if condition is not None:
            return first.apply_if(others, func, condition)
        return first.apply(others, func)

    return wrapper


def distribution(symbols):
    """
    A class decorator that wraps a neural network module's forward output
    into a Dolphin Distribution.

    Usage:
        @dolphin.distribution(range(10))
        class MNISTNet(nn.Module):
            def forward(self, x):
                # ... network logic ...
                return output  # Will be automatically wrapped into Distribution

    Args:
        symbols: The symbols to associate with the distribution output.
            An iterable (list, range, tuple, etc.) of symbols. The same
            object is passed to Distribution on every forward call, so it
            should be re-iterable (not a one-shot generator).

    Returns:
        A decorator that replaces ``cls.forward`` in place, sets
        ``cls._dolphin_decorated = True`` and returns the same class.

    The decorated class's forward method will return
    ``Distribution(<original forward output>, symbols)`` instead of the raw
    tensor, enabling direct arithmetic operations like:
        d1 = model1(x1)  # Returns Distribution
        d2 = model2(x2)  # Returns Distribution
        result = d1 + d2  # Direct addition of distributions
    """
    # Local import keeps this block self-contained regardless of the module header.
    import functools

    def decorator(cls):
        original_forward = cls.forward

        # wraps() keeps the original forward's name, docstring and signature
        # visible to introspection tools.
        @functools.wraps(original_forward)
        def new_forward(self, *args, **kwargs):
            output = original_forward(self, *args, **kwargs)
            return Distribution(output, symbols)

        cls.forward = new_forward
        # Marker attribute; nothing in this block reads it back.
        cls._dolphin_decorated = True

        return cls

    return decorator
120 changes: 108 additions & 12 deletions dolphin/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,28 @@ def l_not(a: Distribution) -> Distribution:
# assert a.type == np.bool_, "All inputs must have the type `np.bool_`"

return a.__compute_possibilities(a, lambda s1, s2 : not s1)

@staticmethod
def from_value(value, reference: Distribution) -> Distribution:
    """
    Create a trivial Distribution from a single value with probability 1.

    Args:
        value: The value to wrap (any Python object)
        reference: A reference Distribution whose ``tags`` tensor supplies
            the batch size and device

    Returns:
        A Distribution with the single symbol ``[value]``. When the
        reference tags have more than one dimension the probabilities have
        shape ``(reference.tags.shape[0], 1)``; otherwise they have shape
        ``(1,)``.

    Usage:
        d = Distribution.from_value(5, existing_dist)
        # d has symbol [5] with prob 1, same batch size as existing_dist
    """
    ref_tags = reference.tags
    device = ref_tags.device
    if ref_tags.dim() > 1:
        trivial_probs = torch.ones((ref_tags.shape[0], 1), device=device)
    else:
        # Unbatched reference: shape[0] is not a batch dimension here, so
        # mirror the scalar lifting done in __compute_possibilities.
        trivial_probs = torch.ones(1, device=device)
    return Distribution(trivial_probs, [value])

def __get_symbols_from_array(self, symbol_list):
if isinstance(symbol_list, np.ndarray) and len(symbol_list.shape) == 1:
Expand Down Expand Up @@ -163,7 +185,7 @@ def __calculate_possibilities(self, dists: List[Distribution], function: Callabl
return Distribution.copy(dist)
all_dists = [dist.sample_top_k(dist.k) for dist in all_dists]

tags_list, combined_src = self.provenance.combine_tag_sources(all_dists[0], all_dists[1])
tags_list, combined_src = self.provenance.combine_tag_sources_multi(all_dists)
num_lists = [len(dist.symbols) for dist in all_dists]
symbol_lists = [dist.symbols for dist in all_dists]
index_combinations = np.indices(num_lists).reshape(len(num_lists), -1).T
Expand All @@ -172,7 +194,7 @@ def __calculate_possibilities(self, dists: List[Distribution], function: Callabl
res_list = [function(*combination) for combination in zip(*args)]
results = self.__get_symbols_from_array(res_list)

final_tags = self.provenance.cartesian_prod(tags_list[0], tags_list[1])
final_tags = self.provenance.cartesian_prod_multi(tags_list)
prod_distribution = Distribution(final_tags, results, dist_as_probs=False, src=combined_src)

if conditional:
Expand All @@ -183,19 +205,30 @@ def __calculate_possibilities(self, dists: List[Distribution], function: Callabl

return d

def __compute_possibilities(self, dist_b: Union[Distribution|np.ndarray|Any], function: Callable, conditional = False) -> Distribution:
def __compute_possibilities(self, dists: Union[List[Union[Distribution, np.ndarray, Any]], Distribution, np.ndarray, Any], function: Callable, conditional = False) -> Distribution:
assert self.provenance is not None, "Provenance not set"
if not isinstance(dist_b, Distribution):
if isinstance(dist_b, np.ndarray):
assert len(dist_b) == len(self.symbols), "Length of symbols must match"
dist_b = Distribution(torch.ones(self.tags.shape[:2], device=self.tags.device), dist_b)
else:
if self.tags.dim() > 1:
dist_b = Distribution(torch.ones((self.tags.shape[0], 1), device=self.tags.device), [dist_b, ])

# Normalize input to a list (handle both list and tuple)
if isinstance(dists, (list, tuple)):
dists = list(dists)
else:
dists = [dists]

# Convert each non-Distribution item to a Distribution
processed_dists = []
for dist_b in dists:
if not isinstance(dist_b, Distribution):
if isinstance(dist_b, np.ndarray):
assert len(dist_b) == len(self.symbols), "Length of symbols must match"
dist_b = Distribution(torch.ones(self.tags.shape[:2], device=self.tags.device), dist_b)
else:
dist_b = Distribution(torch.ones(1, device=self.tags.device), [dist_b, ])
if self.tags.dim() > 1:
dist_b = Distribution(torch.ones((self.tags.shape[0], 1), device=self.tags.device), [dist_b, ])
else:
dist_b = Distribution(torch.ones(1, device=self.tags.device), [dist_b, ])
processed_dists.append(dist_b)

res = self.__calculate_possibilities(dist_b, function, conditional)
res = self.__calculate_possibilities(processed_dists, function, conditional)
return res

def apply(self, distributions: List[Distribution], function: Callable) -> Distribution:
Expand Down Expand Up @@ -313,6 +346,69 @@ def __len__(self) -> int:

def __hash__(self):
return hash(self.id)

# =========================================================================
# PyTorch Integration - Allow Distribution to work with torch functions
# =========================================================================

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
    """
    Enable Distribution to work with PyTorch functions by auto-extracting probabilities.

    When a Distribution is passed to a torch function (e.g., F.cross_entropy),
    this method is called. It replaces every Distribution found in the
    positional and keyword arguments (including inside nested lists/tuples)
    with ``get_probabilities()`` and then calls the original function.

    Args:
        func: The torch function being invoked.
        types: Types of the overriding arguments (not used here).
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call, or None.

    Returns:
        Whatever ``func`` returns for the converted arguments.

    Example:
        dist = model(x)  # Returns Distribution
        loss = F.cross_entropy(dist, target)  # Works directly!
    """
    if kwargs is None:
        kwargs = {}

    def convert_arg(arg):
        # Convert Distribution args to probability tensors.
        if isinstance(arg, Distribution):
            return arg.get_probabilities()
        if isinstance(arg, (list, tuple)):
            # Recursively handle nested lists/tuples (e.g., torch.stack([dist1, dist2])).
            converted = [convert_arg(item) for item in arg]
            if hasattr(arg, '_fields'):
                # namedtuple constructors take one positional argument per
                # field, not a single iterable.
                return type(arg)(*converted)
            return type(arg)(converted)
        return arg

    converted_args = tuple(convert_arg(arg) for arg in args)
    converted_kwargs = {key: convert_arg(val) for key, val in kwargs.items()}

    return func(*converted_args, **converted_kwargs)

def to(self, device):
    """Return a new Distribution whose tags live on ``device``.

    The tags tensor is moved with ``Tensor.to``; symbols and ``src`` are
    passed to the new Distribution unchanged, and the tags are passed with
    ``dist_as_probs=False``.
    """
    return Distribution(
        self.tags.to(device),
        self.symbols,
        dist_as_probs=False,
        src=self.src,
    )

def cpu(self):
    """Return the result of ``self.to('cpu')``, i.e. this Distribution on the CPU."""
    target = 'cpu'
    return self.to(target)

def cuda(self, device=None):
    """Move Distribution to CUDA.

    Args:
        device: ``None`` for the default CUDA device, an integer device
            index, or any device specification accepted by ``to`` (for
            example ``'cuda:1'`` or a ``torch.device``).

    Returns:
        The result of ``self.to(...)`` for the resolved device.
    """
    if device is None:
        return self.to('cuda')
    if isinstance(device, int):
        return self.to(f'cuda:{device}')
    # Already a full device spec; prefixing it with 'cuda:' would yield an
    # invalid string such as 'cuda:cuda:1'.
    return self.to(device)

def __getattr__(self, name):
    """Delegate unknown public attributes to the probability tensor.

    Python calls this only after normal attribute lookup has failed. The
    lookup is then retried on ``self.get_probabilities()``.

    Args:
        name: The attribute name that was not found on the Distribution.

    Returns:
        The attribute of the same name on the probability tensor.

    Raises:
        AttributeError: If ``name`` is a core/internal attribute, starts
            with an underscore, or is not found on the probability tensor.
    """
    # Never delegate dunder/private names: protocol probes such as
    # copy.deepcopy's getattr(x, "__deepcopy__", None) would otherwise pick
    # up the tensor's implementation and act on the tensor instead of the
    # Distribution. The explicit names avoid infinite recursion when a core
    # attribute has not been set yet.
    if name.startswith('_') or name in ('tags', 'symbols', 'src', 'provenance'):
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    try:
        return getattr(self.get_probabilities(), name)
    except AttributeError as err:
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") from err

def filter(self, filter_function) -> Distribution:
filtered_indices = [ filter_function(s) for s in self.symbols ]
Expand Down
14 changes: 8 additions & 6 deletions experiments/mnist/sum_n.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from argparse import ArgumentParser
from tqdm import tqdm

import dolphin
from dolphin import Distribution
from dolphin.provenances import get_provenance

Expand Down Expand Up @@ -100,6 +101,7 @@ def mnist_sum_n_loader(data_dir, sum_n, batch_size_train, batch_size_test):
return train_loader, test_loader


@dolphin.distribution(range(10))
class MNISTNet(nn.Module):
def __init__(self):
super(MNISTNet, self).__init__()
Expand All @@ -126,13 +128,13 @@ def __init__(self, db=None):
self.mnist_net = MNISTNet()

def forward(self, x: Tuple[torch.Tensor, ...]):
for i in range(len(x)):
if i == 0:
a = Distribution(self.mnist_net(x[i]), range(10))
else:
a = a + Distribution(self.mnist_net(x[i]), range(10))
# With @dolphin.distribution decorator, mnist_net already returns Distribution
# so we can directly add the outputs without manual wrapping
a = self.mnist_net(x[0])
for i in range(1, len(x)):
a = a + self.mnist_net(x[i])

return a.get_probabilities() # Tensor b x (sum_n*9 + 1)
return a


def bce_loss(output, ground_truth):
Expand Down
Loading