From 7ccec0ca22fa634b8fb7570cc1e6048f6ac572f4 Mon Sep 17 00:00:00 2001 From: Aaditya Naik Date: Thu, 5 Feb 2026 14:36:06 -0500 Subject: [PATCH 1/4] Added decorators for easier interaction with Dolphin --- README.md | 226 ++++++++++++++ dolphin/__init__.py | 108 +++++++ dolphin/distribution.py | 57 +++- experiments/mnist/sum_n.py | 12 +- tests/__init__.py | 1 + tests/test_decorators.py | 569 ++++++++++++++++++++++++++++++++++ tests/test_distribution.py | 615 +++++++++++++++++++++++++++++++++++++ 7 files changed, 1571 insertions(+), 17 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/test_decorators.py create mode 100644 tests/test_distribution.py diff --git a/README.md b/README.md index e7c5e5c..6c3d09e 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,232 @@ To install Dolphin, first clone the repository. Then use pip: pip install -e . ``` +## Tutorial: Writing Dolphin Programs + +This tutorial shows how to write neurosymbolic programs using Dolphin's decorator-based API. We'll build up from basic concepts to a complete working example. + +### Core Concepts + +Dolphin operates on **Distributions** — objects that map symbolic values to probabilities. A neural network outputs logits over possible symbols, and Dolphin lets you perform symbolic computations while tracking probabilities through the computation. + +### Setup + +```python +import torch +import torch.nn as nn +import torch.nn.functional as F +import dolphin +from dolphin import Distribution +from dolphin.provenances import get_provenance + +# Set the provenance (probability computation method) +Distribution.provenance = get_provenance("damp") +``` + +### Step 1: Annotate Neural Networks with `@dolphin.distribution` + +Use `@dolphin.distribution(symbols)` to make a neural network's forward pass automatically return a Distribution instead of raw logits. + +```python +@dolphin.distribution(range(10)) # Symbols are digits 0-9 +class DigitClassifier(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 32, 3) + self.conv2 = nn.Conv2d(32, 64, 3) + self.fc1 = nn.Linear(1600, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2(x), 2)) + x = x.view(-1, 1600) + x = F.relu(self.fc1(x)) + x = self.fc2(x) + return F.softmax(x, dim=1) # Automatically wrapped into Distribution +``` + +Now when you call `model(image)`, it returns a `Distribution` with symbols `[0, 1, 2, ..., 9]` instead of a tensor. + +### Step 2: Define Symbolic Operations with `@dolphin.function` + +Use `@dolphin.function` to lift regular Python functions to work with Distributions. Write the function as if operating on individual symbols — Dolphin handles the probabilistic computation. + +```python +@dolphin.function +def add(a, b): + return a + b + +@dolphin.function +def multiply(a, b): + return a * b + +@dolphin.function +def is_even(x): + return x % 2 == 0 +``` + +These functions now work seamlessly with both regular values and Distributions: + +```python +# With regular values (works normally) +result = add(3, 5) # Returns 8 + +# With Distributions (computes all possible outcomes) +d1 = model(image1) # Distribution over [0-9] +d2 = model(image2) # Distribution over [0-9] +sum_dist = add(d1, d2) # Distribution over [0-18] +``` + +### Step 3: Compose Operations + +Chain decorated functions to build complex symbolic programs: + +```python +@dolphin.function +def digit_formula(a, b, c): + """Compute a * b + c for digit distributions.""" + return a * b + c + +# Or compose simpler functions +result = add(multiply(d1, d2), d3) +``` + +### Step 4: Extract Probabilities for Training + +Use `.get_probabilities()` to convert back to tensors for loss computation: + +```python +class DigitSumModel(nn.Module): + def __init__(self): + super().__init__() + self.digit_net = DigitClassifier() + + def forward(self, images): + # Each image produces a Distribution over digits + d1 = self.digit_net(images[0]) + d2 = self.digit_net(images[1]) + + # Sum the distributions + result = add(d1, d2) + + # Get probability tensor for loss computation + return result.get_probabilities() + +# Training loop +model = DigitSumModel() +optimizer = torch.optim.Adam(model.parameters()) + +for images, target_sum in dataloader: + output = model(images) # Probabilities over [0-18] + loss = F.cross_entropy(output, target_sum) + loss.backward() + optimizer.step() +``` + +### Complete Example: MNIST Sum + +Here's a complete example that sums N MNIST digit images: + +```python +import torch +import torch.nn as nn +import torch.nn.functional as F +import dolphin +from dolphin import Distribution +from dolphin.provenances import get_provenance + +# Setup provenance +Distribution.provenance = get_provenance("damp") + +# Define the digit classifier with automatic Distribution wrapping +@dolphin.distribution(range(10)) +class MNISTNet(nn.Module): + def __init__(self): + super().__init__() + self.conv1 = nn.Conv2d(1, 32, kernel_size=5) + self.conv2 = nn.Conv2d(32, 64, kernel_size=5) + self.fc1 = nn.Linear(1024, 1024) + self.fc2 = nn.Linear(1024, 10) + + def forward(self, x): + x = F.max_pool2d(self.conv1(x), 2) + x = F.max_pool2d(self.conv2(x), 2) + x = x.view(-1, 1024) + x = F.relu(self.fc1(x)) + x = F.dropout(x, p=0.5, training=self.training) + x = self.fc2(x) + return F.softmax(x, dim=1) + +# Define symbolic addition +@dolphin.function +def add(a, b): + return a + b + +# Model that sums N digits +class MNISTSumN(nn.Module): + def __init__(self): + super().__init__() + self.digit_net = MNISTNet() + + def forward(self, images): + # images is a tuple of N image tensors + result = self.digit_net(images[0]) + for i in range(1, len(images)): + result = add(result, self.digit_net(images[i])) + return result.get_probabilities() + +# Usage +model = MNISTSumN() +images = (img1, img2, img3) # 3 MNIST images +output = model(images) # Probabilities over sums 0-27 +``` + +### Working Directly with Distributions + +You can also work with Distributions directly without decorators: + +```python +# Create distributions manually +probs = torch.tensor([[0.1, 0.3, 0.6]]) +dist = Distribution(probs, ['cat', 'dog', 'bird']) + +# Arithmetic operations +d1 = Distribution(torch.tensor([[0.5, 0.5]]), [1, 2]) +d2 = Distribution(torch.tensor([[0.5, 0.5]]), [10, 20]) + +sum_dist = d1 + d2 # Symbols: [11, 12, 21, 22] +prod_dist = d1 * d2 # Symbols: [10, 20, 20, 40] + +# Map and filter +doubled = d1.map(lambda x: x * 2) # Symbols: [2, 4] +filtered = d1.filter(lambda x: x > 1) # Symbols: [2] + +# Apply custom functions +result = d1.apply(d2, lambda a, b: a ** b) +``` + +### Conditional Operations + +Use the `if_=` parameter for conditional computations: + +```python +@dolphin.function +def safe_divide(a, b): + return a / b + +# Only compute where b != 0 +result = safe_divide(d1, d2, if_=lambda a, b: b != 0) +``` + +### Key Points + +1. **`@dolphin.distribution(symbols)`** — Wrap neural network outputs as Distributions +2. **`@dolphin.function`** — Lift functions to work with Distributions +3. **Distributions track probabilities** — All operations maintain proper probability semantics +4. **`.get_probabilities()`** — Convert back to tensors for training +5. **GPU-accelerated** — All computations run efficiently on GPU + ## Running Experiments To run the experiments, you must first download the data. You can get it from the following drive link: diff --git a/dolphin/__init__.py b/dolphin/__init__.py index 35565c0..5bf573a 100644 --- a/dolphin/__init__.py +++ b/dolphin/__init__.py @@ -21,3 +21,111 @@ def wrapper(*args, **kwargs): else: return d.apply(args, f) return wrapper + + +def function(f): + """ + A decorator that lifts a function operating on symbols to work with Distributions. + + Usage: + @dolphin.function + def add_and_square(a, b): + return (a + b) ** 2 + + # Works with regular values + result = add_and_square(2, 3) # Returns 25 + + # Works with Distributions - applies function to all symbol combinations + d1 = Distribution(probs1, [1, 2, 3]) + d2 = Distribution(probs2, [4, 5]) + result = add_and_square(d1, d2) # Returns Distribution over [25, 36, 36, 49, 49, 64] + + # Mix of Distributions and regular values + result = add_and_square(d1, 10) # Adds 10 to each symbol in d1, then squares + + # Conditional application with 'if_' keyword + @dolphin.function + def divide(a, b): + return a / b + + result = divide(d1, d2, if_=lambda a, b: b != 0) # Only where b != 0 + """ + def wrapper(*args, **kwargs): + condition = kwargs.pop('if_', None) + + # If no args are Distributions, call f normally + if not any(isinstance(arg, Distribution) for arg in args): + return f(*args, **kwargs) + + # Get a reference Distribution for lifting + ref_dist = next(arg for arg in args if isinstance(arg, Distribution)) + + # Ensure first arg is a Distribution (lift if needed) + first = args[0] + if not isinstance(first, Distribution): + first = Distribution.from_value(first, ref_dist) + + # Wrap f to include any remaining kwargs + if kwargs: + func = lambda *a: f(*a, **kwargs) + else: + func = f + + # Apply based on number of arguments + if len(args) == 1: + return first.map(func) + elif len(args) == 2: + # Single remaining arg - __compute_possibilities handles lifting + if condition is not None: + return first.apply_if(args[1], func, condition) + else: + return first.apply(args[1], func) + else: + # Multiple remaining args - lift non-Distributions to Distributions + remaining = [ + arg if isinstance(arg, Distribution) else Distribution.from_value(arg, ref_dist) + for arg in args[1:] + ] + if condition is not None: + return first.apply_if(remaining, func, condition) + else: + return first.apply(remaining, func) + + return wrapper + + +def distribution(symbols): + """ + A class decorator that wraps a neural network module's forward output + into a Dolphin Distribution. + + Usage: + @dolphin.distribution(range(10)) + class MNISTNet(nn.Module): + def forward(self, x): + # ... network logic ... + return logits # Will be automatically wrapped into Distribution + + Args: + symbols: The symbols to associate with the distribution output. + An iterable (list, range, tuple, etc.) of symbols. + + The decorated class's forward method will return a Distribution instead + of raw logits, enabling direct arithmetic operations like: + d1 = model1(x1) # Returns Distribution + d2 = model2(x2) # Returns Distribution + result = d1 + d2 # Direct addition of distributions + """ + def decorator(cls): + original_forward = cls.forward + + def new_forward(self, *args, **kwargs): + logits = original_forward(self, *args, **kwargs) + return Distribution(logits, symbols) + + cls.forward = new_forward + cls._dolphin_decorated = True + + return cls + + return decorator diff --git a/dolphin/distribution.py b/dolphin/distribution.py index 8720fe1..635a518 100644 --- a/dolphin/distribution.py +++ b/dolphin/distribution.py @@ -56,6 +56,28 @@ def l_not(a: Distribution) -> Distribution: # assert a.type == np.bool_, "All inputs must have the type `np.bool_`" return a.__compute_possibilities(a, lambda s1, s2 : not s1) + + @staticmethod + def from_value(value, reference: Distribution) -> Distribution: + """ + Create a trivial Distribution from a single value with probability 1. + + Args: + value: The value to wrap (any Python object) + reference: A reference Distribution to match batch size and device + + Returns: + A Distribution with a single symbol and probability 1, matching + the batch dimensions of the reference distribution. + + Usage: + d = Distribution.from_value(5, existing_dist) + # d has symbol [5] with prob 1, same batch size as existing_dist + """ + batch_size = reference.tags.shape[0] + device = reference.tags.device + trivial_probs = torch.ones((batch_size, 1), device=device) + return Distribution(trivial_probs, [value]) def __get_symbols_from_array(self, symbol_list): if isinstance(symbol_list, np.ndarray) and len(symbol_list.shape) == 1: @@ -163,7 +185,7 @@ def __calculate_possibilities(self, dists: List[Distribution], function: Callabl return Distribution.copy(dist) all_dists = [dist.sample_top_k(dist.k) for dist in all_dists] - tags_list, combined_src = self.provenance.combine_tag_sources(all_dists[0], all_dists[1]) + tags_list, combined_src = self.provenance.combine_tag_sources_multi(all_dists) num_lists = [len(dist.symbols) for dist in all_dists] symbol_lists = [dist.symbols for dist in all_dists] index_combinations = np.indices(num_lists).reshape(len(num_lists), -1).T @@ -172,7 +194,7 @@ def __calculate_possibilities(self, dists: List[Distribution], function: Callabl res_list = [function(*combination) for combination in zip(*args)] results = self.__get_symbols_from_array(res_list) - final_tags = self.provenance.cartesian_prod(tags_list[0], tags_list[1]) + final_tags = self.provenance.cartesian_prod_multi(tags_list) prod_distribution = Distribution(final_tags, results, dist_as_probs=False, src=combined_src) if conditional: @@ -183,19 +205,30 @@ def __calculate_possibilities(self, dists: List[Distribution], function: Callabl return d - def __compute_possibilities(self, dist_b: Union[Distribution|np.ndarray|Any], function: Callable, conditional = False) -> Distribution: + def __compute_possibilities(self, dists: Union[List[Union[Distribution, np.ndarray, Any]], Distribution, np.ndarray, Any], function: Callable, conditional = False) -> Distribution: assert self.provenance is not None, "Provenance not set" - if not isinstance(dist_b, Distribution): - if isinstance(dist_b, np.ndarray): - assert len(dist_b) == len(self.symbols), "Length of symbols must match" - dist_b = Distribution(torch.ones(self.tags.shape[:2], device=self.tags.device), dist_b) - else: - if self.tags.dim() > 1: - dist_b = Distribution(torch.ones((self.tags.shape[0], 1), device=self.tags.device), [dist_b, ]) + + # Normalize input to a list (handle both list and tuple) + if isinstance(dists, (list, tuple)): + dists = list(dists) + else: + dists = [dists] + + # Convert each non-Distribution item to a Distribution + processed_dists = [] + for dist_b in dists: + if not isinstance(dist_b, Distribution): + if isinstance(dist_b, np.ndarray): + assert len(dist_b) == len(self.symbols), "Length of symbols must match" + dist_b = Distribution(torch.ones(self.tags.shape[:2], device=self.tags.device), dist_b) else: - dist_b = Distribution(torch.ones(1, device=self.tags.device), [dist_b, ]) + if self.tags.dim() > 1: + dist_b = Distribution(torch.ones((self.tags.shape[0], 1), device=self.tags.device), [dist_b, ]) + else: + dist_b = Distribution(torch.ones(1, device=self.tags.device), [dist_b, ]) + processed_dists.append(dist_b) - res = self.__calculate_possibilities(dist_b, function, conditional) + res = self.__calculate_possibilities(processed_dists, function, conditional) return res def apply(self, distributions: List[Distribution], function: Callable) -> Distribution: diff --git a/experiments/mnist/sum_n.py b/experiments/mnist/sum_n.py index 326853e..997fb7d 100644 --- a/experiments/mnist/sum_n.py +++ b/experiments/mnist/sum_n.py @@ -14,6 +14,7 @@ from argparse import ArgumentParser from tqdm import tqdm +import dolphin from dolphin import Distribution from dolphin.provenances import get_provenance @@ -100,6 +101,7 @@ def mnist_sum_n_loader(data_dir, sum_n, batch_size_train, batch_size_test): return train_loader, test_loader +@dolphin.distribution(range(10)) class MNISTNet(nn.Module): def __init__(self): super(MNISTNet, self).__init__() @@ -126,11 +128,11 @@ def __init__(self, db=None): self.mnist_net = MNISTNet() def forward(self, x: Tuple[torch.Tensor, ...]): - for i in range(len(x)): - if i == 0: - a = Distribution(self.mnist_net(x[i]), range(10)) - else: - a = a + Distribution(self.mnist_net(x[i]), range(10)) + # With @dolphin.distribution decorator, mnist_net already returns Distribution + # so we can directly add the outputs without manual wrapping + a = self.mnist_net(x[0]) + for i in range(1, len(x)): + a = a + self.mnist_net(x[i]) return a.get_probabilities() # Tensor b x (sum_n*9 + 1) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..c99488f --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +# Dolphin test suite diff --git a/tests/test_decorators.py b/tests/test_decorators.py new file mode 100644 index 0000000..eeb0ace --- /dev/null +++ b/tests/test_decorators.py @@ -0,0 +1,569 @@ +""" +Test suite for Dolphin decorators: +- @dolphin.distribution(symbols) - class decorator for nn.Module +- @dolphin.function - function decorator for lifting operations +- Distribution.from_value() - static method for creating trivial distributions +""" + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + +import dolphin +from dolphin import Distribution +from dolphin.provenances import get_provenance + + +# ============================================================================= +# Fixtures +# ============================================================================= + +@pytest.fixture(autouse=True) +def setup_provenance(): + """Set up provenance before each test.""" + Distribution.provenance = get_provenance("damp") + yield + + +@pytest.fixture +def simple_probs(): + """Simple probability tensor for testing.""" + return torch.tensor([[0.2, 0.3, 0.5], [0.1, 0.6, 0.3]]) + + +@pytest.fixture +def digit_probs(): + """Probability tensor for digit classification (0-9).""" + probs = torch.softmax(torch.randn(4, 10), dim=1) + return probs + + +# ============================================================================= +# Tests for Distribution.from_value() +# ============================================================================= + +class TestFromValue: + """Tests for Distribution.from_value() static method.""" + + def test_from_value_basic(self, simple_probs): + """Test basic from_value functionality.""" + ref_dist = Distribution(simple_probs, ['a', 'b', 'c']) + lifted = Distribution.from_value(42, ref_dist) + + assert isinstance(lifted, Distribution) + assert len(lifted.symbols) == 1 + assert lifted.symbols[0] == 42 + assert lifted.tags.shape[0] == simple_probs.shape[0] # Same batch size + + def test_from_value_string_symbol(self, simple_probs): + """Test from_value with string symbol.""" + ref_dist = Distribution(simple_probs, ['a', 'b', 'c']) + lifted = Distribution.from_value("hello", ref_dist) + + assert lifted.symbols[0] == "hello" + + def test_from_value_tuple_symbol(self, simple_probs): + """Test from_value with tuple symbol.""" + ref_dist = Distribution(simple_probs, ['a', 'b', 'c']) + lifted = Distribution.from_value((1, 2, 3), ref_dist) + + assert lifted.symbols[0] == (1, 2, 3) + + def test_from_value_preserves_device(self, simple_probs): + """Test that from_value preserves device of reference distribution.""" + ref_dist = Distribution(simple_probs, ['a', 'b', 'c']) + lifted = Distribution.from_value(10, ref_dist) + + assert lifted.tags.device == ref_dist.tags.device + + def test_from_value_probability_is_one(self, simple_probs): + """Test that lifted distribution has probability 1.""" + ref_dist = Distribution(simple_probs, ['a', 'b', 'c']) + lifted = Distribution.from_value(5, ref_dist) + + probs = lifted.get_probabilities() + assert torch.allclose(probs, torch.ones_like(probs)) + + +# ============================================================================= +# Tests for @dolphin.function decorator +# ============================================================================= + +class TestFunctionDecorator: + """Tests for @dolphin.function decorator.""" + + def test_function_no_distributions(self): + """Test that function works normally with no distributions.""" + @dolphin.function + def add(a, b): + return a + b + + result = add(3, 5) + assert result == 8 + + def test_function_single_distribution_map(self, simple_probs): + """Test function with single distribution uses map.""" + @dolphin.function + def double(x): + return x * 2 + + dist = Distribution(simple_probs, [1, 2, 3]) + result = double(dist) + + assert isinstance(result, Distribution) + # Symbols should be doubled + assert set(result.symbols) == {2, 4, 6} + + def test_function_two_distributions(self, simple_probs): + """Test function with two distributions.""" + @dolphin.function + def add(a, b): + return a + b + + d1 = Distribution(simple_probs, [1, 2, 3]) + d2 = Distribution(simple_probs, [10, 20, 30]) + result = add(d1, d2) + + assert isinstance(result, Distribution) + # Result should contain all pairwise sums + expected_sums = {1+10, 1+20, 1+30, 2+10, 2+20, 2+30, 3+10, 3+20, 3+30} + assert set(result.symbols) == expected_sums + + def test_function_distribution_and_scalar(self, simple_probs): + """Test function with distribution and scalar.""" + @dolphin.function + def add(a, b): + return a + b + + dist = Distribution(simple_probs, [1, 2, 3]) + result = add(dist, 10) + + assert isinstance(result, Distribution) + assert set(result.symbols) == {11, 12, 13} + + def test_function_scalar_and_distribution(self, simple_probs): + """Test function with scalar first, distribution second.""" + @dolphin.function + def subtract(a, b): + return a - b + + dist = Distribution(simple_probs, [1, 2, 3]) + result = subtract(10, dist) + + assert isinstance(result, Distribution) + # 10 - 1, 10 - 2, 10 - 3 + assert set(result.symbols) == {9, 8, 7} + + def test_function_non_commutative(self): + """Test that argument order is preserved for non-commutative operations.""" + @dolphin.function + def power(base, exp): + return base ** exp + + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [2, 3]) + d2 = Distribution(probs, [2, 3]) + result = power(d1, d2) + + # 2^2=4, 2^3=8, 3^2=9, 3^3=27 + assert set(result.symbols) == {4, 8, 9, 27} + + def test_function_with_condition(self, simple_probs): + """Test function with if_ condition.""" + @dolphin.function + def divide(a, b): + return a / b + + d1 = Distribution(simple_probs, [10, 20, 30]) + d2 = Distribution(simple_probs, [0, 2, 5]) + + result = divide(d1, d2, if_=lambda a, b: b != 0) + + assert isinstance(result, Distribution) + # Should not include divisions by zero + assert 0 not in [s for s in result.symbols if s == float('inf') or s != s] + + def test_function_three_args_chained(self): + """Test function with three arguments using chained operations. + + Note: Distribution.apply only supports combining 2 distributions at a time. + For operations on 3+ values, chain multiple decorated function calls. + """ + @dolphin.function + def multiply(a, b): + return a * b + + @dolphin.function + def add(a, b): + return a + b + + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [10, 20]) + d2 = Distribution(probs, [5, 15]) + + # weighted_sum(a, b, w) = a * w + b * (1 - w), with w = 0.5 + # Can be computed as: add(multiply(d1, 0.5), multiply(d2, 0.5)) + result = add(multiply(d1, 0.5), multiply(d2, 0.5)) + + assert isinstance(result, Distribution) + # 10*0.5 + 5*0.5 = 7.5, 10*0.5 + 15*0.5 = 12.5, etc. + expected = {10*0.5 + 5*0.5, 10*0.5 + 15*0.5, 20*0.5 + 5*0.5, 20*0.5 + 15*0.5} + assert set(result.symbols) == expected + + def test_function_returns_tuple(self): + """Test function that returns a tuple.""" + @dolphin.function + def make_pair(a, b): + return (a, b) + + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [1, 2]) + d2 = Distribution(probs, ['x', 'y']) + result = make_pair(d1, d2) + + assert isinstance(result, Distribution) + expected = {(1, 'x'), (1, 'y'), (2, 'x'), (2, 'y')} + assert set(result.symbols) == expected + + def test_function_returns_string(self): + """Test function that returns strings.""" + @dolphin.function + def concat(a, b): + return f"{a}-{b}" + + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, ['a', 'b']) + d2 = Distribution(probs, ['x', 'y']) + result = concat(d1, d2) + + assert isinstance(result, Distribution) + expected = {'a-x', 'a-y', 'b-x', 'b-y'} + assert set(result.symbols) == expected + + def test_function_preserves_batch_dimension(self, digit_probs): + """Test that batch dimension is preserved.""" + @dolphin.function + def add(a, b): + return a + b + + d1 = Distribution(digit_probs, range(10)) + d2 = Distribution(digit_probs, range(10)) + result = add(d1, d2) + + # Batch size should be preserved + assert result.tags.shape[0] == digit_probs.shape[0] + + +# ============================================================================= +# Tests for @dolphin.distribution decorator +# ============================================================================= + +class TestDistributionDecorator: + """Tests for @dolphin.distribution class decorator.""" + + def test_distribution_basic(self): + """Test basic distribution decorator functionality.""" + @dolphin.distribution(range(3)) + class SimpleNet(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 3) + + def forward(self, x): + return F.softmax(self.linear(x), dim=1) + + net = SimpleNet() + x = torch.randn(4, 10) + result = net(x) + + assert isinstance(result, Distribution) + assert list(result.symbols) == [0, 1, 2] + assert result.tags.shape[0] == 4 # Batch size + + def test_distribution_with_list_symbols(self): + """Test distribution decorator with list symbols.""" + @dolphin.distribution(['cat', 'dog', 'bird']) + class Classifier(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(5, 3) + + def forward(self, x): + return F.softmax(self.linear(x), dim=1) + + net = Classifier() + x = torch.randn(2, 5) + result = net(x) + + assert isinstance(result, Distribution) + assert list(result.symbols) == ['cat', 'dog', 'bird'] + + def test_distribution_preserves_training_mode(self): + """Test that decorator preserves training/eval mode.""" + @dolphin.distribution(range(5)) + class NetWithDropout(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(10, 5) + self.dropout = nn.Dropout(0.5) + + def forward(self, x): + x = self.dropout(self.linear(x)) + return F.softmax(x, dim=1) + + net = NetWithDropout() + x = torch.randn(4, 10) + + # Test in training mode + net.train() + result_train = net(x) + assert isinstance(result_train, Distribution) + + # Test in eval mode + net.eval() + result_eval = net(x) + assert isinstance(result_eval, Distribution) + + def test_distribution_marked_as_decorated(self): + """Test that decorated class has _dolphin_decorated attribute.""" + @dolphin.distribution(range(3)) + class MarkedNet(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(5, 3) + + def forward(self, x): + return F.softmax(self.linear(x), dim=1) + + assert hasattr(MarkedNet, '_dolphin_decorated') + assert MarkedNet._dolphin_decorated is True + + def test_distribution_with_tuple_symbols(self): + """Test distribution decorator with tuple symbols.""" + symbols = [(0, 0), (0, 1), (1, 0), (1, 1)] + + @dolphin.distribution(symbols) + class PairNet(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(5, 4) + + def forward(self, x): + return F.softmax(self.linear(x), dim=1) + + net = PairNet() + x = torch.randn(2, 5) + result = net(x) + + assert isinstance(result, Distribution) + assert len(result.symbols) == 4 + + +# ============================================================================= +# Integration Tests +# ============================================================================= + +class TestIntegration: + """Integration tests combining multiple decorators.""" + + def test_decorated_nets_with_function(self): + """Test using @dolphin.function with @dolphin.distribution decorated nets.""" + @dolphin.distribution(range(10)) + class DigitNet(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(20, 10) + + def forward(self, x): + return F.softmax(self.linear(x), dim=1) + + @dolphin.function + def add(a, b): + return a + b + + net = DigitNet() + x1 = torch.randn(4, 20) + x2 = torch.randn(4, 20) + + d1 = net(x1) + d2 = net(x2) + + result = add(d1, d2) + + assert isinstance(result, Distribution) + # Sum of two digits (0-9) should give 0-18 + assert min(result.symbols) >= 0 + assert max(result.symbols) <= 18 + + def test_chained_operations(self): + """Test chaining multiple decorated functions.""" + @dolphin.function + def add(a, b): + return a + b + + @dolphin.function + def multiply(a, b): + return a * b + + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [1, 2]) + d2 = Distribution(probs, [3, 4]) + d3 = Distribution(probs, [10, 20]) + + # (d1 + d2) * d3 + sum_result = add(d1, d2) + final_result = multiply(sum_result, d3) + + assert isinstance(final_result, Distribution) + # (1+3)*10=40, (1+3)*20=80, (1+4)*10=50, (1+4)*20=100, etc. + expected = {(1+3)*10, (1+3)*20, (1+4)*10, (1+4)*20, + (2+3)*10, (2+3)*20, (2+4)*10, (2+4)*20} + assert set(final_result.symbols) == expected + + def test_mnist_sum_pattern(self): + """Test the MNIST sum pattern from the experiment.""" + @dolphin.distribution(range(10)) + class SimpleDigitNet(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(5, 10) + + def forward(self, x): + return F.softmax(self.linear(x), dim=1) + + @dolphin.function + def add(a, b): + return a + b + + net = SimpleDigitNet() + + # Simulate summing n digit predictions + n = 3 + batch_size = 2 + inputs = [torch.randn(batch_size, 5) for _ in range(n)] + + # Sum all predictions + result = net(inputs[0]) + for i in range(1, n): + result = add(result, net(inputs[i])) + + assert isinstance(result, Distribution) + # Sum of n digits (each 0-9) should be in range [0, 9*n] + assert min(result.symbols) >= 0 + assert max(result.symbols) <= 9 * n + + +# ============================================================================= +# Edge Cases +# ============================================================================= + +class TestEdgeCases: + """Edge case tests.""" + + def test_function_with_zero_in_symbols(self, simple_probs): + """Test handling of zero in symbols.""" + @dolphin.function + def add(a, b): + return a + b + + d1 = Distribution(simple_probs, [0, 1, 2]) + d2 = Distribution(simple_probs, [0, 10, 20]) + result = add(d1, d2) + + assert 0 in result.symbols # 0 + 0 = 0 + + def test_function_with_negative_symbols(self, simple_probs): + """Test handling of negative symbols.""" + @dolphin.function + def add(a, b): + return a + b + + d1 = Distribution(simple_probs, [-1, 0, 1]) + d2 = Distribution(simple_probs, [-10, 0, 10]) + result = add(d1, d2) + + assert -11 in result.symbols # -1 + -10 + assert 11 in result.symbols # 1 + 10 + + def test_function_with_float_symbols(self, simple_probs): + """Test handling of float symbols.""" + @dolphin.function + def add(a, b): + return a + b + + d1 = Distribution(simple_probs, [0.1, 0.2, 0.3]) + d2 = Distribution(simple_probs, [1.0, 2.0, 3.0]) + result = add(d1, d2) + + assert isinstance(result, Distribution) + assert len(result.symbols) == 9 # 3 * 3 combinations + + def test_function_with_boolean_result(self): + """Test function returning boolean.""" + @dolphin.function + def greater_than(a, b): + return a > b + + probs3 = torch.tensor([[0.33, 0.33, 0.34]]) + probs2 = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs3, [1, 5, 10]) + d2 = Distribution(probs2, [3, 6]) + result = greater_than(d1, d2) + + assert isinstance(result, Distribution) + assert True in result.symbols or False in result.symbols + + def test_function_single_symbol_distribution(self): + """Test with single-symbol distribution.""" + @dolphin.function + def add(a, b): + return a + b + + probs = torch.tensor([[1.0]]) + d1 = Distribution(probs, [5]) + d2 = Distribution(probs, [10]) + result = add(d1, d2) + + assert isinstance(result, Distribution) + assert list(result.symbols) == [15] + + def test_distribution_empty_forward_args(self): + """Test distribution decorator with no forward args (uses internal state).""" + @dolphin.distribution(['a', 'b']) + class StatefulNet(nn.Module): + def __init__(self): + super().__init__() + self.state = torch.randn(1, 2) + + def forward(self): + return F.softmax(self.state, dim=1) + + net = StatefulNet() + result = net() + + assert isinstance(result, Distribution) + assert list(result.symbols) == ['a', 'b'] + + def test_function_kwargs_preserved(self): + """Test that kwargs (other than if_) are preserved.""" + @dolphin.function + def format_result(a, b, prefix=""): + return f"{prefix}{a + b}" + + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [1, 2]) + result = format_result(d1, 10, prefix="sum=") + + assert isinstance(result, Distribution) + assert "sum=11" in result.symbols + assert "sum=12" in result.symbols + + +# ============================================================================= +# Run tests +# ============================================================================= + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/test_distribution.py b/tests/test_distribution.py new file mode 100644 index 0000000..906bb32 --- /dev/null +++ b/tests/test_distribution.py @@ -0,0 +1,615 @@ +""" +Test suite for core Dolphin Distribution functionality. +Tests basic operations, arithmetic, comparisons, and utility methods. +""" + +import pytest +import torch +import numpy as np + +from dolphin import Distribution +from dolphin.provenances import get_provenance + + +# ============================================================================= +# Fixtures +# ============================================================================= + +@pytest.fixture(autouse=True) +def setup_provenance(): + """Set up provenance before each test.""" + Distribution.provenance = get_provenance("damp") + yield + + +@pytest.fixture +def simple_dist(): + """Simple distribution for testing.""" + probs = torch.tensor([[0.2, 0.3, 0.5]]) + return Distribution(probs, [1, 2, 3]) + + +@pytest.fixture +def batch_dist(): + """Batched distribution for testing.""" + probs = torch.tensor([[0.2, 0.3, 0.5], [0.1, 0.6, 0.3]]) + return Distribution(probs, [10, 20, 30]) + + +# ============================================================================= +# Tests for Distribution Creation +# ============================================================================= + +class TestDistributionCreation: + """Tests for creating Distribution objects.""" + + def test_create_with_list_symbols(self): + """Test creating distribution with list symbols.""" + probs = torch.tensor([[0.3, 0.7]]) + dist = Distribution(probs, ['a', 'b']) + + assert len(dist.symbols) == 2 + assert list(dist.symbols) == ['a', 'b'] + + def test_create_with_range_symbols(self): + """Test creating distribution with range symbols.""" + probs = torch.tensor([[0.25, 0.25, 0.25, 0.25]]) + dist = Distribution(probs, range(4)) + + assert len(dist.symbols) == 4 + assert list(dist.symbols) == [0, 1, 2, 3] + + def test_create_with_tuple_symbols(self): + """Test creating distribution with tuple symbols.""" + probs = torch.tensor([[0.5, 0.5]]) + dist = Distribution(probs, [(0, 0), (1, 1)]) + + assert len(dist.symbols) == 2 + + def test_create_batched_distribution(self): + """Test creating batched distribution.""" + probs = torch.tensor([[0.3, 0.7], [0.6, 0.4], [0.5, 0.5]]) + dist = Distribution(probs, ['x', 'y']) + + assert len(dist) == 3 # Batch size + assert len(dist.symbols) == 2 + + def test_create_requires_matching_dimensions(self): + """Test that probs and symbols must match.""" + probs = torch.tensor([[0.3, 0.7]]) + with pytest.raises(AssertionError): + Distribution(probs, [1, 2, 3]) # 3 symbols but 2 probs + + def test_get_probabilities(self, simple_dist): + """Test getting probabilities from distribution.""" + probs = simple_dist.get_probabilities() + + assert isinstance(probs, torch.Tensor) + assert probs.shape[-1] == 3 + + +# ============================================================================= +# Tests for Arithmetic Operations +# ============================================================================= + +class TestArithmeticOperations: + """Tests for arithmetic operations on distributions.""" + + def test_addition_two_distributions(self): + """Test adding two distributions.""" + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [1, 2]) + d2 = Distribution(probs, [10, 20]) + + result = d1 + d2 + + assert isinstance(result, Distribution) + assert set(result.symbols) == {11, 12, 21, 22} + + def test_addition_with_scalar(self): + """Test adding distribution and scalar.""" + probs = torch.tensor([[0.5, 0.5]]) + d = Distribution(probs, [1, 2]) + + result = d + 10 + + assert set(result.symbols) == {11, 12} + + def test_subtraction(self): + """Test subtracting distributions.""" + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [10, 20]) + d2 = Distribution(probs, [3, 5]) + + result = d1 - d2 + + assert set(result.symbols) == {10-3, 10-5, 20-3, 20-5} + + def test_multiplication(self): + """Test multiplying distributions.""" + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [2, 3]) + d2 = Distribution(probs, [4, 5]) + + result = d1 * d2 + + assert set(result.symbols) == {8, 10, 12, 15} + + def test_division(self): + """Test dividing distributions.""" + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [10, 20]) + d2 = Distribution(probs, [2, 5]) + + result = d1 / d2 + + assert set(result.symbols) == {10/2, 10/5, 20/2, 20/5} + + def test_floor_division(self): + """Test floor dividing distributions.""" + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [10, 20]) + d2 = Distribution(probs, [3, 7]) + + result = d1 // d2 + + assert set(result.symbols) == {10//3, 10//7, 20//3, 20//7} + + def test_modulo(self): + """Test modulo operation on distributions.""" + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [10, 17]) + d2 = Distribution(probs, [3, 5]) + + result = d1 % d2 + + assert set(result.symbols) == {10%3, 10%5, 17%3, 17%5} + + def test_power(self): + """Test power operation on distributions.""" + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [2, 3]) + d2 = Distribution(probs, [2, 3]) + + result = d1 ** d2 + + assert set(result.symbols) == {4, 8, 9, 27} + + +# ============================================================================= +# Tests for Comparison Operations +# ============================================================================= + +class TestComparisonOperations: + """Tests for comparison operations on distributions.""" + + def test_equality(self): + """Test equality comparison.""" + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [1, 2]) + d2 = Distribution(probs, [1, 3]) + + result = d1 == d2 + + assert isinstance(result, Distribution) + assert True in result.symbols # 1 == 1 + assert False in result.symbols # 1 != 3, 2 != 1, 2 != 3 + + def test_inequality(self): + """Test inequality comparison.""" + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [1, 2]) + d2 = Distribution(probs, [1, 3]) + + result = d1 != d2 + + assert True in result.symbols + assert False in result.symbols + + def test_greater_than(self): + """Test greater than comparison.""" + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [5, 10]) + d2 = Distribution(probs, [3, 8]) + + result = d1 > d2 + + # 5>3=T, 5>8=F, 10>3=T, 10>8=T + assert True in result.symbols + assert False in result.symbols + + def test_greater_equal(self): + """Test greater than or equal comparison.""" + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [5, 10]) + d2 = Distribution(probs, [5, 8]) + + result = d1 >= d2 + + assert True in result.symbols # 5>=5, 10>=5, 10>=8 + + def test_less_than(self): + """Test less than comparison.""" + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [3, 8]) + d2 = Distribution(probs, [5, 10]) + + result = d1 < d2 + + assert True in result.symbols + assert False in result.symbols + + def test_less_equal(self): + """Test less than or equal comparison.""" + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [5, 8]) + d2 = Distribution(probs, [5, 10]) + + result = d1 <= d2 + + assert True in result.symbols + + +# ============================================================================= +# Tests for Logical Operations +# ============================================================================= + +class TestLogicalOperations: + """Tests for logical operations on distributions.""" + + def test_logical_and(self): + """Test logical AND operation.""" + probs1 = torch.tensor([[0.7, 0.3]]) + probs2 = torch.tensor([[0.6, 0.4]]) + d1 = Distribution(probs1, ['a', 'b']) + d2 = Distribution(probs2, ['a', 'c']) + + result = d1 & d2 + + assert isinstance(result, Distribution) + # Union of symbols + assert set(result.symbols) == {'a', 'b', 'c'} + + def test_logical_or(self): + """Test logical OR operation.""" + probs1 = torch.tensor([[0.7, 0.3]]) + probs2 = torch.tensor([[0.6, 0.4]]) + d1 = Distribution(probs1, ['a', 'b']) + d2 = Distribution(probs2, ['a', 'c']) + + result = d1 | d2 + + assert isinstance(result, Distribution) + assert set(result.symbols) == {'a', 'b', 'c'} + + def test_logical_not(self): + """Test logical NOT (invert) operation.""" + probs = torch.tensor([[0.7, 0.3]]) + d = Distribution(probs, ['a', 'b']) + + result = ~d + + assert isinstance(result, Distribution) + assert result.inverted != d.inverted + + def test_static_l_and(self): + """Test static l_and method.""" + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [True, False]) + d2 = Distribution(probs, [True, False]) + + result = Distribution.l_and(d1, d2) + + assert isinstance(result, Distribution) + # T and T = T, T and F = F, F and T = F, F and F = F + assert True in result.symbols + assert False in result.symbols + + def test_static_l_or(self): + """Test static l_or method.""" + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [True, False]) + d2 = Distribution(probs, [True, False]) + + result = Distribution.l_or(d1, d2) + + assert isinstance(result, Distribution) + assert True in result.symbols + assert False in result.symbols + + def test_static_l_not(self): + """Test static l_not method.""" + probs = torch.tensor([[0.5, 0.5]]) + d = Distribution(probs, [True, False]) + + result = Distribution.l_not(d) + + assert isinstance(result, Distribution) + assert True in result.symbols + assert False in result.symbols + + +# ============================================================================= +# Tests for Map, Filter, Apply +# ============================================================================= + +class TestMapFilterApply: + """Tests for map, filter, and apply operations.""" + + def test_map_function(self, simple_dist): + """Test mapping a function over symbols.""" + result = simple_dist.map(lambda x: x * 2) + + assert isinstance(result, Distribution) + assert set(result.symbols) == {2, 4, 6} + + def test_map_to_string(self, simple_dist): + """Test mapping symbols to strings.""" + result = simple_dist.map(lambda x: f"val_{x}") + + assert set(result.symbols) == {"val_1", "val_2", "val_3"} + + def test_filter_function(self, simple_dist): + """Test filtering symbols.""" + result = simple_dist.filter(lambda x: x > 1) + + assert isinstance(result, Distribution) + assert set(result.symbols) == {2, 3} + + def test_filter_all(self, simple_dist): + """Test filtering that keeps all symbols.""" + result = simple_dist.filter(lambda x: x > 0) + + assert set(result.symbols) == {1, 2, 3} + + def test_apply_two_distributions(self): + """Test apply with two distributions.""" + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [1, 2]) + d2 = Distribution(probs, [10, 20]) + + result = d1.apply(d2, lambda a, b: a + b) + + assert set(result.symbols) == {11, 12, 21, 22} + + def test_apply_with_condition(self): + """Test apply_if with condition.""" + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [10, 20]) + d2 = Distribution(probs, [0, 5]) + + result = d1.apply_if(d2, lambda a, b: a / b, lambda a, b: b != 0) + + assert isinstance(result, Distribution) + # Only divisions where b != 0 + assert 10/5 in result.symbols + assert 20/5 in result.symbols + + def test_apply_multiple_distributions(self): + """Test apply with multiple distributions.""" + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [1, 2]) + d2 = Distribution(probs, [10, 20]) + d3 = Distribution(probs, [100, 200]) + + result = d1.apply([d2, d3], lambda a, b, c: a + b + c) + + expected = {a + b + c for a in [1,2] for b in [10,20] for c in [100,200]} + assert set(result.symbols) == expected + + +# ============================================================================= +# Tests for Indexing and Iteration +# ============================================================================= + +class TestIndexingIteration: + """Tests for indexing and iteration over distributions.""" + + def test_len(self, batch_dist): + """Test getting length (batch size) of distribution.""" + assert len(batch_dist) == 2 + + def test_getitem_single_index(self, batch_dist): + """Test indexing with single integer.""" + item = batch_dist[0] + + assert isinstance(item, Distribution) + assert len(item) == 1 + + def test_getitem_slice(self, batch_dist): + """Test slicing distribution.""" + items = batch_dist[0:2] + + assert isinstance(items, Distribution) + + def test_iteration(self, batch_dist): + """Test iterating over distribution.""" + items = list(batch_dist) + + assert len(items) == 2 + for item in items: + assert isinstance(item, Distribution) + + +# ============================================================================= +# Tests for Utility Methods +# ============================================================================= + +class TestUtilityMethods: + """Tests for utility methods.""" + + def test_copy(self, simple_dist): + """Test copying a distribution.""" + copy = Distribution.copy(simple_dist) + + assert isinstance(copy, Distribution) + assert set(copy.symbols) == set(simple_dist.symbols) + + def test_softmax(self): + """Test softmax operation.""" + probs = torch.tensor([[1.0, 2.0, 3.0]]) + dist = Distribution(probs, ['a', 'b', 'c']) + + result = dist.softmax() + + assert isinstance(result, Distribution) + result_probs = result.get_probabilities() + assert torch.allclose(result_probs.sum(dim=-1), torch.tensor([1.0]), atol=1e-5) + + def test_drop_symbol(self, simple_dist): + """Test dropping a symbol.""" + result = simple_dist.drop_symbol(2) + + assert 2 not in result.symbols + assert set(result.symbols) == {1, 3} + + def test_map_symbols(self): + """Test mapping to new symbol set.""" + probs = torch.tensor([[0.3, 0.4, 0.3]]) + dist = Distribution(probs, [1, 2, 3]) + + result = dist.map_symbols([1, 2, 3, 4, 5]) + + assert len(result.symbols) == 5 + # Original symbols should have their probabilities preserved + + def test_diff(self): + """Test diff operation.""" + probs3 = torch.tensor([[0.33, 0.33, 0.34]]) + d1 = Distribution(probs3, [1, 2, 3]) + d2 = Distribution(probs3, [2, 3, 4]) + + result = d1.diff(d2) + + # Should contain symbols in d1 but not in d2 + assert 1 in result.symbols + assert 2 not in result.symbols + assert 3 not in result.symbols + + def test_repr(self, simple_dist): + """Test string representation.""" + repr_str = repr(simple_dist) + + assert "Symbols" in repr_str + assert "Distribution" in repr_str + + def test_hash(self, simple_dist): + """Test that distributions are hashable.""" + h = hash(simple_dist) + + assert isinstance(h, int) + + def test_sample_top_k(self): + """Test sampling top k symbols.""" + probs = torch.tensor([[0.1, 0.2, 0.3, 0.4]]) + dist = Distribution(probs, [1, 2, 3, 4]) + + result = dist.sample_top_k(2) + + assert isinstance(result, Distribution) + assert len(result.symbols) <= 2 + + +# ============================================================================= +# Tests for Static Methods +# ============================================================================= + +class TestStaticMethods: + """Tests for static methods.""" + + def test_from_value_integer(self, simple_dist): + """Test from_value with integer.""" + lifted = Distribution.from_value(42, simple_dist) + + assert isinstance(lifted, Distribution) + assert len(lifted.symbols) == 1 + assert lifted.symbols[0] == 42 + + def test_from_value_string(self, simple_dist): + """Test from_value with string.""" + lifted = Distribution.from_value("hello", simple_dist) + + assert lifted.symbols[0] == "hello" + + def test_from_value_tuple(self, simple_dist): + """Test from_value with tuple.""" + lifted = Distribution.from_value((1, 2, 3), simple_dist) + + assert lifted.symbols[0] == (1, 2, 3) + + def test_stack_distributions(self): + """Test stacking distributions.""" + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [1, 2]) + d2 = Distribution(probs, [1, 2]) + + result = Distribution.stack([d1, d2]) + + assert isinstance(result, Distribution) + assert len(result) == 2 + + +# ============================================================================= +# Tests for Edge Cases +# ============================================================================= + +class TestEdgeCases: + """Tests for edge cases.""" + + def test_single_symbol_distribution(self): + """Test distribution with single symbol.""" + probs = torch.tensor([[1.0]]) + dist = Distribution(probs, [42]) + + assert len(dist.symbols) == 1 + result = dist + dist + assert set(result.symbols) == {84} + + def test_operations_preserve_batch_size(self): + """Test that operations preserve batch dimension.""" + probs = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.6, 0.4]]) + d1 = Distribution(probs, [1, 2]) + d2 = Distribution(probs, [10, 20]) + + result = d1 + d2 + + assert result.tags.shape[0] == 3 + + def test_float_symbols(self): + """Test distribution with float symbols.""" + probs = torch.tensor([[0.5, 0.5]]) + dist = Distribution(probs, [0.5, 1.5]) + + result = dist * 2 + + assert set(result.symbols) == {1.0, 3.0} + + def test_negative_symbols(self): + """Test distribution with negative symbols.""" + probs = torch.tensor([[0.5, 0.5]]) + dist = Distribution(probs, [-5, 5]) + + result = dist + dist + + assert -10 in result.symbols + assert 0 in result.symbols + assert 10 in result.symbols + + def test_none_symbol_handling(self): + """Test handling of None in conditional operations.""" + probs = torch.tensor([[0.5, 0.5]]) + d1 = Distribution(probs, [10, 20]) + d2 = Distribution(probs, [0, 5]) + + # apply_if returns None for failed conditions, then drop_symbol removes them + result = d1.apply_if(d2, lambda a, b: a / b if b != 0 else None, lambda a, b: b != 0) + + assert isinstance(result, Distribution) + + +# ============================================================================= +# Run tests +# ============================================================================= + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From b0a4815d4db61d678e250a1425e03bdb521108c2 Mon Sep 17 00:00:00 2001 From: Aaditya Naik Date: Thu, 5 Feb 2026 15:34:44 -0500 Subject: [PATCH 2/4] Added detailed tutorial for Dolphin --- README.md | 151 ++++++++++++++++++++----------------- dolphin/distribution.py | 63 ++++++++++++++++ experiments/mnist/sum_n.py | 2 +- tests/test_distribution.py | 149 ++++++++++++++++++++++++++++++++++++ 4 files changed, 293 insertions(+), 72 deletions(-) diff --git a/README.md b/README.md index 6c3d09e..fa5638c 100644 --- a/README.md +++ b/README.md @@ -100,94 +100,103 @@ def digit_formula(a, b, c): result = add(multiply(d1, d2), d3) ``` -### Step 4: Extract Probabilities for Training +### Step 4: Build a Trainable Model -Use `.get_probabilities()` to convert back to tensors for loss computation: +Wrap your symbolic computation in an `nn.Module`. The forward pass returns a Distribution, which you convert to probabilities at the end: ```python -class DigitSumModel(nn.Module): - def __init__(self): +class MNISTSumN(nn.Module): + def __init__(self, n_digits=2): super().__init__() - self.digit_net = DigitClassifier() + self.n_digits = n_digits + self.digit_net = DigitClassifier() # From Step 1 def forward(self, images): - # Each image produces a Distribution over digits - d1 = self.digit_net(images[0]) - d2 = self.digit_net(images[1]) - - # Sum the distributions - result = add(d1, d2) - - # Get probability tensor for loss computation - return result.get_probabilities() - -# Training loop -model = DigitSumModel() -optimizer = torch.optim.Adam(model.parameters()) - -for images, target_sum in dataloader: - output = model(images) # Probabilities over [0-18] - loss = F.cross_entropy(output, target_sum) - loss.backward() - optimizer.step() + # images: (batch_size, n_digits, 1, 28, 28) + result = self.digit_net(images[:, 0]) + for i in range(1, self.n_digits): + result = add(result, self.digit_net(images[:, i])) + return result.get_probabilities() # Convert to tensor for loss ``` -### Complete Example: MNIST Sum +The model takes a batch of image tuples and returns a probability tensor over possible sums. For 2 digits, the output shape is `(batch_size, 19)` — probabilities for sums 0 through 18. + +### Step 5: Compute Loss -Here's a complete example that sums N MNIST digit images: +Use `.get_probabilities()` to get a standard PyTorch tensor, then apply any loss function: ```python -import torch -import torch.nn as nn -import torch.nn.functional as F -import dolphin -from dolphin import Distribution -from dolphin.provenances import get_provenance +model = MNISTSumN(n_digits=2) +images = batch_of_image_pairs # (batch, 2, 1, 28, 28) +target_sums = labels # (batch,) with values in [0, 18] -# Setup provenance -Distribution.provenance = get_provenance("damp") +output = model(images) # (batch, 19) probability tensor +loss = F.cross_entropy(output, target_sums) +``` -# Define the digit classifier with automatic Distribution wrapping -@dolphin.distribution(range(10)) -class MNISTNet(nn.Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(1, 32, kernel_size=5) - self.conv2 = nn.Conv2d(32, 64, kernel_size=5) - self.fc1 = nn.Linear(1024, 1024) - self.fc2 = nn.Linear(1024, 10) +The output of `.get_probabilities()` is a regular `torch.Tensor` with `requires_grad=True`, so it works with any PyTorch loss function. - def forward(self, x): - x = F.max_pool2d(self.conv1(x), 2) - x = F.max_pool2d(self.conv2(x), 2) - x = x.view(-1, 1024) - x = F.relu(self.fc1(x)) - x = F.dropout(x, p=0.5, training=self.training) - x = self.fc2(x) - return F.softmax(x, dim=1) +### Step 6: Backpropagation -# Define symbolic addition -@dolphin.function -def add(a, b): - return a + b +Call `loss.backward()` as usual. Gradients flow through the symbolic operations back to the neural network: -# Model that sums N digits -class MNISTSumN(nn.Module): - def __init__(self): - super().__init__() - self.digit_net = MNISTNet() +```python +optimizer.zero_grad() +output = model(images) +loss = F.cross_entropy(output, target_sums) +loss.backward() # Gradients propagate through add() back to DigitClassifier +optimizer.step() +``` - def forward(self, images): - # images is a tuple of N image tensors - result = self.digit_net(images[0]) - for i in range(1, len(images)): - result = add(result, self.digit_net(images[i])) - return result.get_probabilities() - -# Usage -model = MNISTSumN() -images = (img1, img2, img3) # 3 MNIST images -output = model(images) # Probabilities over sums 0-27 +All Dolphin operations (addition, multiplication, composition) are implemented with differentiable tensor operations, so the standard PyTorch autograd machinery works without modification. + +### Step 7: Training Loop + +Put it together into a standard training loop: + +```python +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model = MNISTSumN(n_digits=2).to(device) +optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + +for epoch in range(10): + model.train() + total_loss = 0 + + for images, target_sums in train_loader: + images = images.to(device) + target_sums = target_sums.to(device) + + optimizer.zero_grad() + output = model(images) + loss = F.cross_entropy(output, target_sums) + loss.backward() + optimizer.step() + + total_loss += loss.item() + + print(f"Epoch {epoch+1}: Loss = {total_loss / len(train_loader):.4f}") +``` + +### Step 8: Evaluation + +Evaluate using `argmax` to get the most likely sum: + +```python +model.eval() +correct = 0 + +with torch.no_grad(): + for images, target_sums in test_loader: + images = images.to(device) + target_sums = target_sums.to(device) + + output = model(images) + predictions = output.argmax(dim=1) + correct += (predictions == target_sums).sum().item() + +accuracy = correct / len(test_loader.dataset) +print(f"Test Accuracy: {accuracy:.2%}") ``` ### Working Directly with Distributions diff --git a/dolphin/distribution.py b/dolphin/distribution.py index 635a518..227f419 100644 --- a/dolphin/distribution.py +++ b/dolphin/distribution.py @@ -346,6 +346,69 @@ def __len__(self) -> int: def __hash__(self): return hash(self.id) + + # ========================================================================= + # PyTorch Integration - Allow Distribution to work with torch functions + # ========================================================================= + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + """ + Enable Distribution to work with PyTorch functions by auto-extracting probabilities. + + When a Distribution is passed to a torch function (e.g., F.cross_entropy), + this method is called. It extracts probabilities from any Distribution arguments + and passes them to the original function. + + Example: + dist = model(x) # Returns Distribution + loss = F.cross_entropy(dist, target) # Works directly! + """ + if kwargs is None: + kwargs = {} + + # Convert Distribution args to probability tensors + def convert_arg(arg): + if isinstance(arg, Distribution): + return arg.get_probabilities() + elif isinstance(arg, (list, tuple)): + # Recursively handle nested lists/tuples (e.g., torch.stack([dist1, dist2])) + return type(arg)(convert_arg(a) for a in arg) + return arg + + converted_args = tuple(convert_arg(arg) for arg in args) + converted_kwargs = {k: convert_arg(v) for k, v in kwargs.items()} + + result = func(*converted_args, **converted_kwargs) + return result + + def to(self, device): + """Move Distribution to a device.""" + new_tags = self.tags.to(device) + return Distribution(new_tags, self.symbols, dist_as_probs=False, src=self.src) + + def cpu(self): + """Move Distribution to CPU.""" + return self.to('cpu') + + def cuda(self, device=None): + """Move Distribution to CUDA.""" + if device is None: + return self.to('cuda') + return self.to(f'cuda:{device}') + + def __getattr__(self, name): + """Delegate unknown attributes to the probability tensor.""" + # Avoid infinite recursion by checking for internal attributes + if name in ('tags', 'symbols', 'src', 'provenance'): + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + try: + probs = self.get_probabilities() + attr = getattr(probs, name) + return attr + except AttributeError: + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") def filter(self, filter_function) -> Distribution: filtered_indices = [ filter_function(s) for s in self.symbols ] diff --git a/experiments/mnist/sum_n.py b/experiments/mnist/sum_n.py index 997fb7d..4f62fd9 100644 --- a/experiments/mnist/sum_n.py +++ b/experiments/mnist/sum_n.py @@ -134,7 +134,7 @@ def forward(self, x: Tuple[torch.Tensor, ...]): for i in range(1, len(x)): a = a + self.mnist_net(x[i]) - return a.get_probabilities() # Tensor b x (sum_n*9 + 1) + return a def bce_loss(output, ground_truth): diff --git a/tests/test_distribution.py b/tests/test_distribution.py index 906bb32..3545e5e 100644 --- a/tests/test_distribution.py +++ b/tests/test_distribution.py @@ -549,6 +549,155 @@ def test_stack_distributions(self): assert len(result) == 2 +# ============================================================================= +# Tests for PyTorch Integration (__torch_function__) +# ============================================================================= + +class TestPyTorchIntegration: + """Tests for Distribution working with PyTorch functions.""" + + def test_cross_entropy_direct(self): + """Test Distribution works directly with F.cross_entropy.""" + import torch.nn.functional as F + + probs = torch.tensor([[0.1, 0.2, 0.7], [0.8, 0.1, 0.1]]) + dist = Distribution(probs, [0, 1, 2]) + target = torch.tensor([2, 0]) + + # Should work directly without calling get_probabilities() + loss = F.cross_entropy(dist, target) + + assert isinstance(loss, torch.Tensor) + assert loss.dim() == 0 # scalar loss + + def test_nll_loss_direct(self): + """Test Distribution works directly with F.nll_loss.""" + import torch.nn.functional as F + + # Use log_softmax style probs + logits = torch.tensor([[0.1, 0.2, 0.7], [0.8, 0.1, 0.1]]) + log_probs = F.log_softmax(logits, dim=-1) + dist = Distribution(log_probs, [0, 1, 2]) + target = torch.tensor([2, 0]) + + loss = F.nll_loss(dist, target) + + assert isinstance(loss, torch.Tensor) + + def test_softmax_function(self): + """Test torch.softmax works on Distribution.""" + logits = torch.tensor([[1.0, 2.0, 3.0]]) + dist = Distribution(logits, ['a', 'b', 'c']) + + result = torch.softmax(dist, dim=-1) + + assert isinstance(result, torch.Tensor) + assert torch.allclose(result.sum(dim=-1), torch.tensor([1.0])) + + def test_log_softmax_function(self): + """Test torch.log_softmax works on Distribution.""" + logits = torch.tensor([[1.0, 2.0, 3.0]]) + dist = Distribution(logits, ['a', 'b', 'c']) + + result = torch.log_softmax(dist, dim=-1) + + assert isinstance(result, torch.Tensor) + assert (result <= 0).all() # log probs are negative or zero + + def test_torch_sum(self): + """Test torch.sum works on Distribution.""" + probs = torch.tensor([[0.2, 0.3, 0.5]]) + dist = Distribution(probs, [1, 2, 3]) + + result = torch.sum(dist) + + assert isinstance(result, torch.Tensor) + assert torch.isclose(result, torch.tensor(1.0)) + + def test_torch_mean(self): + """Test torch.mean works on Distribution.""" + probs = torch.tensor([[0.2, 0.3, 0.5]]) + dist = Distribution(probs, [1, 2, 3]) + + result = torch.mean(dist) + + assert isinstance(result, torch.Tensor) + + def test_shape_property(self): + """Test shape property returns probability tensor shape.""" + probs = torch.tensor([[0.2, 0.3, 0.5], [0.1, 0.6, 0.3]]) + dist = Distribution(probs, [1, 2, 3]) + + assert dist.shape == torch.Size([2, 3]) + + def test_device_property(self): + """Test device property.""" + probs = torch.tensor([[0.2, 0.3, 0.5]]) + dist = Distribution(probs, [1, 2, 3]) + + assert dist.device == torch.device('cpu') + + def test_dim_method(self): + """Test dim method returns number of dimensions.""" + probs = torch.tensor([[0.2, 0.3, 0.5]]) + dist = Distribution(probs, [1, 2, 3]) + + assert dist.dim() == 2 + + def test_size_method(self): + """Test size method.""" + probs = torch.tensor([[0.2, 0.3, 0.5], [0.1, 0.6, 0.3]]) + dist = Distribution(probs, [1, 2, 3]) + + assert dist.size() == torch.Size([2, 3]) + assert dist.size(0) == 2 + assert dist.size(1) == 3 + + def test_to_device(self): + """Test moving Distribution to different device.""" + probs = torch.tensor([[0.2, 0.3, 0.5]]) + dist = Distribution(probs, [1, 2, 3]) + + # Move to CPU (should work on any system) + cpu_dist = dist.to('cpu') + + assert isinstance(cpu_dist, Distribution) + assert cpu_dist.device == torch.device('cpu') + assert list(cpu_dist.symbols) == list(dist.symbols) + + def test_cpu_method(self): + """Test cpu() method.""" + probs = torch.tensor([[0.2, 0.3, 0.5]]) + dist = Distribution(probs, [1, 2, 3]) + + cpu_dist = dist.cpu() + + assert cpu_dist.device == torch.device('cpu') + + def test_multiple_distributions_in_function(self): + """Test multiple Distributions passed to torch function.""" + probs1 = torch.tensor([[0.2, 0.3, 0.5]]) + probs2 = torch.tensor([[0.1, 0.4, 0.5]]) + dist1 = Distribution(probs1, ['a', 'b', 'c']) + dist2 = Distribution(probs2, ['x', 'y', 'z']) + + # torch.stack should convert both to tensors + result = torch.stack([dist1, dist2]) + + assert isinstance(result, torch.Tensor) + assert result.shape == torch.Size([2, 1, 3]) + + def test_torch_matmul(self): + """Test torch.matmul works with Distribution.""" + probs = torch.tensor([[0.2, 0.3, 0.5]]) + dist = Distribution(probs, [1, 2, 3]) + weight = torch.tensor([[1.0], [2.0], [3.0]]) + + result = torch.matmul(dist, weight) + + assert isinstance(result, torch.Tensor) + + # ============================================================================= # Tests for Edge Cases # ============================================================================= From 55d65f2306c3f3af188732f6884611ff274b10b9 Mon Sep 17 00:00:00 2001 From: Aaditya Naik Date: Thu, 5 Feb 2026 16:33:04 -0500 Subject: [PATCH 3/4] Added detailed tutorial for Dolphin --- README.md | 228 +++------------- tutorial.ipynb | 715 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 749 insertions(+), 194 deletions(-) create mode 100644 tutorial.ipynb diff --git a/README.md b/README.md index fa5638c..14ff82c 100644 --- a/README.md +++ b/README.md @@ -9,241 +9,81 @@ To install Dolphin, first clone the repository. Then use pip: pip install -e . ``` -## Tutorial: Writing Dolphin Programs +## Tutorial -This tutorial shows how to write neurosymbolic programs using Dolphin's decorator-based API. We'll build up from basic concepts to a complete working example. +For a complete walkthrough, see **[tutorial.ipynb](tutorial.ipynb)** or **[tutorial.py](tutorial.py)**, which train a model to compute `(a + b)² // c` from MNIST digit images. -### Core Concepts - -Dolphin operates on **Distributions** — objects that map symbolic values to probabilities. A neural network outputs logits over possible symbols, and Dolphin lets you perform symbolic computations while tracking probabilities through the computation. - -### Setup +### Quick Start ```python -import torch -import torch.nn as nn -import torch.nn.functional as F import dolphin from dolphin import Distribution from dolphin.provenances import get_provenance -# Set the provenance (probability computation method) +# Configure provenance Distribution.provenance = get_provenance("damp") ``` -### Step 1: Annotate Neural Networks with `@dolphin.distribution` - -Use `@dolphin.distribution(symbols)` to make a neural network's forward pass automatically return a Distribution instead of raw logits. +**1. Wrap neural networks with `@dolphin.distribution`** — outputs become Distributions over symbols: ```python @dolphin.distribution(range(10)) # Symbols are digits 0-9 -class DigitClassifier(nn.Module): - def __init__(self): - super().__init__() - self.conv1 = nn.Conv2d(1, 32, 3) - self.conv2 = nn.Conv2d(32, 64, 3) - self.fc1 = nn.Linear(1600, 128) - self.fc2 = nn.Linear(128, 10) - +class DigitNet(nn.Module): def forward(self, x): - x = F.relu(F.max_pool2d(self.conv1(x), 2)) - x = F.relu(F.max_pool2d(self.conv2(x), 2)) - x = x.view(-1, 1600) - x = F.relu(self.fc1(x)) - x = self.fc2(x) - return F.softmax(x, dim=1) # Automatically wrapped into Distribution + # ... your network ... + return F.softmax(logits, dim=1) # Automatically wrapped ``` -Now when you call `model(image)`, it returns a `Distribution` with symbols `[0, 1, 2, ..., 9]` instead of a tensor. - -### Step 2: Define Symbolic Operations with `@dolphin.function` - -Use `@dolphin.function` to lift regular Python functions to work with Distributions. Write the function as if operating on individual symbols — Dolphin handles the probabilistic computation. +**2. Define symbolic operations with `@dolphin.function`** — write plain Python, Dolphin handles the probabilities: ```python @dolphin.function -def add(a, b): - return a + b - -@dolphin.function -def multiply(a, b): - return a * b - -@dolphin.function -def is_even(x): - return x % 2 == 0 +def formula(a, b, c): + return (a + b) ** 2 // c ``` -These functions now work seamlessly with both regular values and Distributions: +**3. Combine them in a model:** ```python -# With regular values (works normally) -result = add(3, 5) # Returns 8 - -# With Distributions (computes all possible outcomes) -d1 = model(image1) # Distribution over [0-9] -d2 = model(image2) # Distribution over [0-9] -sum_dist = add(d1, d2) # Distribution over [0-18] -``` - -### Step 3: Compose Operations - -Chain decorated functions to build complex symbolic programs: - -```python -@dolphin.function -def digit_formula(a, b, c): - """Compute a * b + c for digit distributions.""" - return a * b + c - -# Or compose simpler functions -result = add(multiply(d1, d2), d3) -``` - -### Step 4: Build a Trainable Model - -Wrap your symbolic computation in an `nn.Module`. The forward pass returns a Distribution, which you convert to probabilities at the end: - -```python -class MNISTSumN(nn.Module): - def __init__(self, n_digits=2): +class FormulaModel(nn.Module): + def __init__(self): super().__init__() - self.n_digits = n_digits - self.digit_net = DigitClassifier() # From Step 1 + self.digit_net = DigitNet() def forward(self, images): - # images: (batch_size, n_digits, 1, 28, 28) - result = self.digit_net(images[:, 0]) - for i in range(1, self.n_digits): - result = add(result, self.digit_net(images[:, i])) - return result.get_probabilities() # Convert to tensor for loss -``` - -The model takes a batch of image tuples and returns a probability tensor over possible sums. For 2 digits, the output shape is `(batch_size, 19)` — probabilities for sums 0 through 18. - -### Step 5: Compute Loss - -Use `.get_probabilities()` to get a standard PyTorch tensor, then apply any loss function: - -```python -model = MNISTSumN(n_digits=2) -images = batch_of_image_pairs # (batch, 2, 1, 28, 28) -target_sums = labels # (batch,) with values in [0, 18] - -output = model(images) # (batch, 19) probability tensor -loss = F.cross_entropy(output, target_sums) + a = self.digit_net(images[0]) + b = self.digit_net(images[1]) + c = self.digit_net(images[2]) + + result = formula(a, b, c, if_=lambda a, b, c: c != 0) # skip division by zero + return result.get_probabilities() ``` -The output of `.get_probabilities()` is a regular `torch.Tensor` with `requires_grad=True`, so it works with any PyTorch loss function. - -### Step 6: Backpropagation - -Call `loss.backward()` as usual. Gradients flow through the symbolic operations back to the neural network: +**4. Train with standard PyTorch** — gradients flow through symbolic ops: ```python -optimizer.zero_grad() output = model(images) -loss = F.cross_entropy(output, target_sums) -loss.backward() # Gradients propagate through add() back to DigitClassifier +loss = F.cross_entropy(output, targets) +loss.backward() # end-to-end differentiable! optimizer.step() ``` -All Dolphin operations (addition, multiplication, composition) are implemented with differentiable tensor operations, so the standard PyTorch autograd machinery works without modification. - -### Step 7: Training Loop - -Put it together into a standard training loop: - -```python -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -model = MNISTSumN(n_digits=2).to(device) -optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) - -for epoch in range(10): - model.train() - total_loss = 0 - - for images, target_sums in train_loader: - images = images.to(device) - target_sums = target_sums.to(device) - - optimizer.zero_grad() - output = model(images) - loss = F.cross_entropy(output, target_sums) - loss.backward() - optimizer.step() - - total_loss += loss.item() - - print(f"Epoch {epoch+1}: Loss = {total_loss / len(train_loader):.4f}") -``` - -### Step 8: Evaluation - -Evaluate using `argmax` to get the most likely sum: - -```python -model.eval() -correct = 0 - -with torch.no_grad(): - for images, target_sums in test_loader: - images = images.to(device) - target_sums = target_sums.to(device) - - output = model(images) - predictions = output.argmax(dim=1) - correct += (predictions == target_sums).sum().item() - -accuracy = correct / len(test_loader.dataset) -print(f"Test Accuracy: {accuracy:.2%}") -``` - -### Working Directly with Distributions - -You can also work with Distributions directly without decorators: +### Other Features ```python -# Create distributions manually -probs = torch.tensor([[0.1, 0.3, 0.6]]) -dist = Distribution(probs, ['cat', 'dog', 'bird']) - -# Arithmetic operations -d1 = Distribution(torch.tensor([[0.5, 0.5]]), [1, 2]) -d2 = Distribution(torch.tensor([[0.5, 0.5]]), [10, 20]) - -sum_dist = d1 + d2 # Symbols: [11, 12, 21, 22] -prod_dist = d1 * d2 # Symbols: [10, 20, 20, 40] - -# Map and filter -doubled = d1.map(lambda x: x * 2) # Symbols: [2, 4] -filtered = d1.filter(lambda x: x > 1) # Symbols: [2] +# Conditional computation +result = formula(a, b, c, if_=lambda a, b, c: c != 0) -# Apply custom functions -result = d1.apply(d2, lambda a, b: a ** b) -``` - -### Conditional Operations - -Use the `if_=` parameter for conditional computations: - -```python -@dolphin.function -def safe_divide(a, b): - return a / b +# Map output to fixed size +result.map_symbols(list(range(num_classes))).get_probabilities() -# Only compute where b != 0 -result = safe_divide(d1, d2, if_=lambda a, b: b != 0) +# Direct Distribution operations +d1 + d2 # add +d1 * d2 # multiply +d1.map(lambda x: x * 2) # transform symbols +d1.filter(lambda x: x > 0) # filter symbols ``` -### Key Points - -1. **`@dolphin.distribution(symbols)`** — Wrap neural network outputs as Distributions -2. **`@dolphin.function`** — Lift functions to work with Distributions -3. **Distributions track probabilities** — All operations maintain proper probability semantics -4. **`.get_probabilities()`** — Convert back to tensors for training -5. **GPU-accelerated** — All computations run efficiently on GPU - ## Running Experiments To run the experiments, you must first download the data. You can get it from the following drive link: diff --git a/tutorial.ipynb b/tutorial.ipynb new file mode 100644 index 0000000..aae13a5 --- /dev/null +++ b/tutorial.ipynb @@ -0,0 +1,715 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Dolphin Tutorial: Neurosymbolic Learning with PyTorch\n", + "\n", + "This notebook walks through building and training a neurosymbolic model using Dolphin. We'll train a neural network to compute `(a + b)² // c` from MNIST digit images, where `a`, `b`, and `c` are the digits shown in the images.\n", + "\n", + "The model only receives supervision on the final result — it must learn to recognize individual digits through the symbolic computation." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 1: Setup\n", + "\n", + "Import Dolphin and configure the provenance (the method used for probability computation)." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: mps\n" + ] + } + ], + "source": [ + "import os\n", + "import time\n", + "import random\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "import torchvision\n", + "from torch.utils.data import Dataset, DataLoader\n", + "from tqdm import tqdm\n", + "\n", + "import dolphin\n", + "from dolphin import Distribution\n", + "from dolphin.provenances import get_provenance\n", + "\n", + "# Configure Dolphin\n", + "Distribution.provenance = get_provenance(\"damp\")\n", + "\n", + "# Device setup\n", + "if torch.cuda.is_available():\n", + " device = torch.device(\"cuda\")\n", + "elif torch.backends.mps.is_available():\n", + " device = torch.device(\"mps\")\n", + "else:\n", + " device = torch.device(\"cpu\")\n", + "\n", + "print(f\"Using device: {device}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2: Define the Neural Network\n", + "\n", + "Use `@dolphin.distribution(symbols)` to make the network output a Distribution over digit symbols instead of raw logits." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "@dolphin.distribution(range(10)) # Symbols are digits 0-9\n", + "class DigitNet(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.conv1 = nn.Conv2d(1, 32, kernel_size=5)\n", + " self.conv2 = nn.Conv2d(32, 64, kernel_size=5)\n", + " self.fc1 = nn.Linear(1024, 256)\n", + " self.fc2 = nn.Linear(256, 10)\n", + "\n", + " def forward(self, x):\n", + " x = F.relu(F.max_pool2d(self.conv1(x), 2))\n", + " x = F.relu(F.max_pool2d(self.conv2(x), 2))\n", + " x = x.view(-1, 1024)\n", + " x = F.relu(self.fc1(x))\n", + " x = F.dropout(x, p=0.5, training=self.training)\n", + " x = self.fc2(x)\n", + " return F.softmax(x, dim=1) # Wrapped into Distribution automatically" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3: Define Symbolic Operations\n", + "\n", + "Use `@dolphin.function` to lift Python functions to work with Distributions. Write the function as if operating on regular values — Dolphin handles the probabilistic computation." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "@dolphin.function\n", + "def formula(a, b, c):\n", + " return (a + b) ** 2 // c" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4: Build the Model\n", + "\n", + "Combine the neural network and symbolic operations into a trainable `nn.Module`. The `if_` parameter filters out invalid cases (division by zero), and `map_symbols` ensures a fixed output size." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "# Max output value: (9+9)^2 // 1 = 324, so we need 325 classes (0-324)\n", + "NUM_CLASSES = 325\n", + "\n", + "class FormulaModel(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.digit_net = DigitNet()\n", + "\n", + " def forward(self, images):\n", + " # images: tuple of 3 tensors, each (batch_size, 1, 28, 28)\n", + " a = self.digit_net(images[0])\n", + " b = self.digit_net(images[1])\n", + " c = self.digit_net(images[2])\n", + " \n", + " # Skip c=0 to avoid division by zero\n", + " result = formula(a, b, c, if_=lambda a, b, c: c != 0)\n", + " \n", + " # Map to fixed output size\n", + " return result.map_symbols(list(range(NUM_CLASSES))).get_probabilities()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 5: Create MNIST Dataset\n", + "\n", + "Load MNIST and create triplets of digit images. Each sample returns three images and the target `(a + b)² // c`." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "mnist_transform = torchvision.transforms.Compose([\n", + " torchvision.transforms.ToTensor(),\n", + " torchvision.transforms.Normalize((0.1307,), (0.3081,))\n", + "])\n", + "\n", + "class MNISTFormulaDataset(Dataset):\n", + " \"\"\"Dataset that returns triplets of MNIST digits for computing (a + b)^2 // c.\"\"\"\n", + " \n", + " def __init__(self, root, train=True, download=True):\n", + " self.mnist = torchvision.datasets.MNIST(\n", + " root, train=train, download=download, transform=mnist_transform\n", + " )\n", + " self.index_map = list(range(len(self.mnist)))\n", + " random.shuffle(self.index_map)\n", + " \n", + " def __len__(self):\n", + " return len(self.mnist) // 3\n", + " \n", + " def __getitem__(self, idx):\n", + " img_a, digit_a = self.mnist[self.index_map[idx * 3]]\n", + " img_b, digit_b = self.mnist[self.index_map[idx * 3 + 1]]\n", + " img_c, digit_c = self.mnist[self.index_map[idx * 3 + 2]]\n", + " \n", + " if digit_c == 0:\n", + " target = 0\n", + " valid = False\n", + " else:\n", + " target = (digit_a + digit_b) ** 2 // digit_c\n", + " valid = True\n", + " \n", + " return img_a, img_b, img_c, target, valid\n", + "\n", + " @staticmethod\n", + " def collate_fn(batch):\n", + " imgs_a = torch.stack([item[0] for item in batch])\n", + " imgs_b = torch.stack([item[1] for item in batch])\n", + " imgs_c = torch.stack([item[2] for item in batch])\n", + " targets = torch.tensor([item[3] for item in batch]).long()\n", + " valids = torch.tensor([item[4] for item in batch])\n", + " return (imgs_a, imgs_b, imgs_c), targets, valids" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 6: Define Loss and Trainer\n", + "\n", + "Use BCE loss with one-hot encoding (same as the experiments in `sum_n.py`). The Trainer class handles the training loop with tqdm progress bars." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "def bce_loss(output, ground_truth):\n", + " \"\"\"BCE loss with one-hot encoding, same as sum_n.py.\"\"\"\n", + " gt = F.one_hot(ground_truth, num_classes=NUM_CLASSES).float()\n", + " return F.binary_cross_entropy(output, gt)\n", + "\n", + "class Trainer:\n", + " def __init__(self, train_loader, test_loader, learning_rate=1e-3):\n", + " self.device = device\n", + " self.network = FormulaModel().to(self.device)\n", + " self.optimizer = torch.optim.Adam(self.network.parameters(), lr=learning_rate)\n", + " self.train_loader = train_loader\n", + " self.test_loader = test_loader\n", + " self.best_loss = float('inf')\n", + " self.best_acc = 0\n", + " self.epoch_times = []\n", + "\n", + " def train_epoch(self, epoch):\n", + " self.network.train()\n", + " t_begin = time.time()\n", + " \n", + " iter = tqdm(self.train_loader, total=len(self.train_loader))\n", + " for (imgs, targets, valids) in iter:\n", + " if not valids.any():\n", + " continue\n", + " \n", + " imgs = tuple(img[valids].to(self.device) for img in imgs)\n", + " targets = targets[valids].to(self.device)\n", + " \n", + " self.optimizer.zero_grad()\n", + " output = self.network(imgs)\n", + " \n", + " targets = targets.clamp(0, NUM_CLASSES - 1)\n", + " \n", + " loss = bce_loss(output, targets)\n", + " loss.backward()\n", + " self.optimizer.step()\n", + " \n", + " iter.set_description(f\"[Train Epoch {epoch}] Loss: {loss.item():.4f}\")\n", + " \n", + " epoch_time = time.time() - t_begin\n", + " self.epoch_times.append(epoch_time)\n", + " print(f\"Epoch time: {epoch_time:.2f}s\")\n", + "\n", + " def test_epoch(self, epoch):\n", + " self.network.eval()\n", + " test_loss = 0\n", + " correct = 0\n", + " total = 0\n", + " \n", + " with torch.no_grad():\n", + " iter = tqdm(self.test_loader, total=len(self.test_loader))\n", + " for (imgs, targets, valids) in iter:\n", + " if not valids.any():\n", + " continue\n", + " \n", + " imgs = tuple(img[valids].to(self.device) for img in imgs)\n", + " targets = targets[valids].to(self.device)\n", + " \n", + " output = self.network(imgs)\n", + " \n", + " targets = targets.clamp(0, NUM_CLASSES - 1)\n", + " \n", + " test_loss += bce_loss(output, targets).item()\n", + " pred = output.argmax(dim=1)\n", + " correct += pred.eq(targets).sum().item()\n", + " total += targets.shape[0]\n", + " \n", + " acc = 100. * correct / total\n", + " iter.set_description(f\"[Test Epoch {epoch}] Loss: {test_loss:.4f}, Acc: {correct}/{total} ({acc:.2f}%)\")\n", + " \n", + " acc = 100. * correct / total if total > 0 else 0\n", + " if test_loss < self.best_loss:\n", + " self.best_loss = test_loss\n", + " if acc > self.best_acc:\n", + " self.best_acc = acc\n", + " \n", + " print(f\"Best loss: {self.best_loss:.4f}, Best acc: {self.best_acc:.2f}%\")\n", + "\n", + " def train(self, n_epochs):\n", + " self.test_epoch(0)\n", + " for epoch in range(1, n_epochs + 1):\n", + " self.train_epoch(epoch)\n", + " self.test_epoch(epoch)\n", + " \n", + " if self.epoch_times:\n", + " avg_time = sum(self.epoch_times) / len(self.epoch_times)\n", + " print(f\"Average epoch time: {avg_time:.2f}s\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 7: Train the Model\n", + "\n", + "Create the datasets, dataloaders, and run training." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train samples: 20000, Test samples: 3333\n" + ] + } + ], + "source": [ + "# Hyperparameters\n", + "n_epochs = 10\n", + "batch_size = 64\n", + "learning_rate = 1e-3\n", + "seed = 42\n", + "\n", + "# Set seeds for reproducibility\n", + "torch.manual_seed(seed)\n", + "random.seed(seed)\n", + "\n", + "# Data directory\n", + "data_dir = \"./data\"\n", + "\n", + "# Create datasets\n", + "train_dataset = MNISTFormulaDataset(data_dir, train=True, download=True)\n", + "test_dataset = MNISTFormulaDataset(data_dir, train=False, download=True)\n", + "\n", + "train_loader = DataLoader(\n", + " train_dataset, \n", + " batch_size=batch_size, \n", + " shuffle=True, \n", + " collate_fn=MNISTFormulaDataset.collate_fn\n", + ")\n", + "test_loader = DataLoader(\n", + " test_dataset, \n", + " batch_size=batch_size, \n", + " collate_fn=MNISTFormulaDataset.collate_fn\n", + ")\n", + "\n", + "print(f\"Train samples: {len(train_dataset)}, Test samples: {len(test_dataset)}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Test Epoch 0] Loss: 0.7886, Acc: 157/3035 (5.17%): 100%|██████████| 53/53 [00:01<00:00, 39.22it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Best loss: 0.7886, Best acc: 5.17%\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Train Epoch 1] Loss: 0.0008: 100%|██████████| 313/313 [00:10<00:00, 30.27it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch time: 10.34s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Test Epoch 1] Loss: 0.0455, Acc: 2883/3035 (94.99%): 100%|██████████| 53/53 [00:00<00:00, 61.40it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Best loss: 0.0455, Best acc: 94.99%\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Train Epoch 2] Loss: 0.0004: 100%|██████████| 313/313 [00:09<00:00, 32.39it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch time: 9.66s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Test Epoch 2] Loss: 0.0286, Acc: 2944/3035 (97.00%): 100%|██████████| 53/53 [00:00<00:00, 64.39it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Best loss: 0.0286, Best acc: 97.00%\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Train Epoch 3] Loss: 0.0010: 100%|██████████| 313/313 [00:09<00:00, 33.37it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch time: 9.38s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Test Epoch 3] Loss: 0.0251, Acc: 2945/3035 (97.03%): 100%|██████████| 53/53 [00:00<00:00, 62.20it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Best loss: 0.0251, Best acc: 97.03%\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Train Epoch 4] Loss: 0.0000: 100%|██████████| 313/313 [00:09<00:00, 32.34it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch time: 9.68s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Test Epoch 4] Loss: 0.0265, Acc: 2949/3035 (97.17%): 100%|██████████| 53/53 [00:00<00:00, 58.34it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Best loss: 0.0251, Best acc: 97.17%\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Train Epoch 5] Loss: 0.0003: 100%|██████████| 313/313 [00:09<00:00, 32.43it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch time: 9.65s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Test Epoch 5] Loss: 0.0206, Acc: 2967/3035 (97.76%): 100%|██████████| 53/53 [00:00<00:00, 59.53it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Best loss: 0.0206, Best acc: 97.76%\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Train Epoch 6] Loss: 0.0002: 100%|██████████| 313/313 [00:10<00:00, 31.27it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch time: 10.01s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Test Epoch 6] Loss: 0.0220, Acc: 2966/3035 (97.73%): 100%|██████████| 53/53 [00:00<00:00, 61.00it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Best loss: 0.0206, Best acc: 97.76%\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Train Epoch 7] Loss: 0.0001: 100%|██████████| 313/313 [00:09<00:00, 33.12it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch time: 9.45s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Test Epoch 7] Loss: 0.0194, Acc: 2984/3035 (98.32%): 100%|██████████| 53/53 [00:00<00:00, 59.16it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Best loss: 0.0194, Best acc: 98.32%\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Train Epoch 8] Loss: 0.0000: 100%|██████████| 313/313 [00:09<00:00, 32.28it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch time: 9.70s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Test Epoch 8] Loss: 0.0183, Acc: 2980/3035 (98.19%): 100%|██████████| 53/53 [00:00<00:00, 64.31it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Best loss: 0.0183, Best acc: 98.32%\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Train Epoch 9] Loss: 0.0000: 100%|██████████| 313/313 [00:09<00:00, 33.48it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch time: 9.35s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Test Epoch 9] Loss: 0.0186, Acc: 2976/3035 (98.06%): 100%|██████████| 53/53 [00:00<00:00, 64.17it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Best loss: 0.0183, Best acc: 98.32%\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Train Epoch 10] Loss: 0.0002: 100%|██████████| 313/313 [00:09<00:00, 32.91it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch time: 9.51s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[Test Epoch 10] Loss: 0.0176, Acc: 2978/3035 (98.12%): 100%|██████████| 53/53 [00:00<00:00, 63.34it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Best loss: 0.0176, Best acc: 98.32%\n", + "Average epoch time: 9.67s\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "# Create trainer and train\n", + "trainer = Trainer(train_loader, test_loader, learning_rate)\n", + "trainer.train(n_epochs)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## What we covered\n", + "\n", + "That's the basic Dolphin workflow. The main things to remember:\n", + "\n", + "- `@dolphin.distribution` turns your network's output into a Distribution over symbols\n", + "- `@dolphin.function` lets you write symbolic logic as plain Python — Dolphin figures out the probabilities\n", + "- Use `if_=` when you need to skip certain cases (like c=0 here)\n", + "- `map_symbols()` pads the output to a fixed size when your formula doesn't hit every possible value\n", + "\n", + "Everything else is just standard PyTorch. Gradients flow through the symbolic ops, so training works as expected." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 7a5cb321f880c9ecbf281eeb84647f8e33a78b13 Mon Sep 17 00:00:00 2001 From: Aaditya Naik Date: Tue, 17 Mar 2026 20:55:04 -0400 Subject: [PATCH 4/4] Update README.md --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 14ff82c..5330941 100644 --- a/README.md +++ b/README.md @@ -86,7 +86,7 @@ d1.filter(lambda x: x > 0) # filter symbols ## Running Experiments -To run the experiments, you must first download the data. You can get it from the following drive link: +To run the experiments, you must first download the data. You can get it from the following drive link: https://drive.google.com/file/d/1cP1W9OluX_lOUWn6jZ9QpEb5qtlopXCP/view?usp=sharing ### MNIST Sum-N @@ -117,4 +117,4 @@ python run.py --cuda --n-epochs=10 --seed 1831 --learning-rate 1e-5 ```bash cd experiments/mugen python run.py --phase=train --train_size=1000 --provenance=damp --seed=1234 --epochs=100 --use_cuda -``` \ No newline at end of file +```