Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
dash[diskcache]~=3.2
dash-bootstrap-components~=2.0
dwave-ocean-sdk~=9.0
dwave-pytorch-plugin~=0.2
dwave-pytorch-plugin~=0.3
einops~=0.8
matplotlib~=3.10
pyyaml~=6.0
Expand Down
72 changes: 1 addition & 71 deletions src/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,85 +26,15 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Optional

import torch
from dimod import Sampler, SampleSet

if TYPE_CHECKING:
from dwave.plugins.torch.models import GraphRestrictedBoltzmannMachine

from .utils.persistent_qpu_sampler import PersistentQPUSampleHelper


class RadialBasisFunction(torch.nn.Module):
    """Multi-bandwidth radial basis function kernel.

    Evaluates a sum of exponential kernels at geometrically spaced
    bandwidths ``mul_factor ** (i - num_features // 2)`` for
    ``i in range(num_features)``, each scaled by a base bandwidth that is
    either fixed at construction time or derived per batch from the data.
    """

    def __init__(
        self,
        num_features: int,
        mul_factor: Union[int, float] = 2.0,
        bandwidth: Optional[float] = None,
    ):
        """Initialize the kernel.

        Args:
            num_features: Number of bandwidth multipliers (component kernels).
            mul_factor: Geometric spacing between successive bandwidths.
            bandwidth: Fixed base bandwidth. If ``None``, a per-batch
                heuristic (mean pairwise distance) is used instead.
        """
        super().__init__()
        # Exponents are centered so the multipliers straddle 1 (roughly half
        # below, half above the base bandwidth).
        bandwidth_multipliers = mul_factor ** (torch.arange(num_features) - num_features // 2)
        # Buffer (not parameter): moves with the module across devices but is
        # not trained.
        self.register_buffer("bandwidth_multipliers", bandwidth_multipliers)
        self.bandwidth = bandwidth

    def get_bandwidth(
        self, l2_distance_matrix: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, float]:
        """Return the base bandwidth for the radial basis function.

        If a fixed ``bandwidth`` was given at construction it is returned;
        otherwise the mean pairwise distance of the batch is used as a
        heuristic.

        Args:
            l2_distance_matrix: Square matrix of pairwise L2 distances.
                Required (and only used) when no fixed bandwidth was set.

        Raises:
            ValueError: If ``bandwidth`` is ``None`` and no distance matrix
                was provided.
        """
        if self.bandwidth is not None:
            return self.bandwidth

        # Previously an ``assert`` — raise explicitly so the check survives
        # running under ``python -O``.
        if l2_distance_matrix is None:
            raise ValueError(
                "l2_distance_matrix is required when no fixed bandwidth is set"
            )

        num_samples = l2_distance_matrix.shape[0]
        # Mean of the off-diagonal distances: the n * (n - 1) divisor excludes
        # the zero diagonal entries from the average.
        return l2_distance_matrix.sum() / (num_samples * (num_samples - 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the summed multi-bandwidth kernel matrix for ``x``.

        Returns a ``(n, n)`` matrix where entry ``(i, j)`` sums
        ``exp(-||x_i - x_j|| / b_k)`` over all bandwidths ``b_k``.
        """
        distance_matrix = torch.cdist(x, x, p=2)
        # Detach so the data-dependent bandwidth heuristic does not receive
        # gradients from the distance computation.
        bandwidth = self.get_bandwidth(distance_matrix.detach()) * self.bandwidth_multipliers
        # NOTE(review): the exponent uses the (unsquared) L2 distance, i.e. a
        # Laplacian-style kernel rather than a Gaussian — preserved as-is.
        return torch.exp(-distance_matrix.unsqueeze(0) / bandwidth.reshape(-1, 1, 1)).sum(dim=0)


def mmd_loss(
    spins: torch.Tensor,
    kernel: RadialBasisFunction,
    grbm: GraphRestrictedBoltzmannMachine,
    sampler: Sampler,
    sampler_kwargs: dict,
    linear_range: tuple[float, float],
    quadratic_range: tuple[float, float],
    prefactor: float,
) -> torch.Tensor:
    """Compute the maximum mean discrepancy between ``spins`` and model samples.

    Draws samples from the GRBM (without gradient tracking) and compares them
    to the encoder-produced ``spins`` under ``kernel``. The kernel means
    include the diagonal self-similarity terms, so this is the biased
    (V-statistic) MMD estimator.

    Args:
        spins: Tensor of spin strings; flattened to ``(-1, num_spins)``.
        kernel: Radial basis function used as the MMD kernel.
        grbm: Model to sample from.
        sampler: dimod sampler used by ``grbm.sample``.
        sampler_kwargs: Extra parameters forwarded to the sampler.
        linear_range: Range clamp for linear biases during sampling.
        quadratic_range: Range clamp for quadratic biases during sampling.
        prefactor: Scaling prefactor for the sampled Hamiltonian.

    Returns:
        Scalar tensor holding the MMD value; gradients flow through the
        ``spins`` side only, since sampling runs under ``torch.no_grad``.
    """
    with torch.no_grad():
        samples = grbm.sample(
            sampler,
            prefactor=prefactor,
            device=spins.device,
            linear_range=linear_range,
            quadratic_range=quadratic_range,
            sample_params=sampler_kwargs,
        )

    spins = spins.reshape(-1, spins.shape[-1])

    # One joint kernel evaluation over the stacked data, then slice out the
    # data-data, model-model, and cross blocks.
    kernel_matrix = kernel(torch.vstack((spins, samples)))
    num_spin_strings = spins.shape[0]
    spin_spin_kernels = kernel_matrix[:num_spin_strings, :num_spin_strings]
    sample_sample_kernels = kernel_matrix[num_spin_strings:, num_spin_strings:]
    spin_sample_kernels = kernel_matrix[:num_spin_strings, num_spin_strings:]

    # MMD^2 estimate: E[k(x,x')] - 2 E[k(x,y)] + E[k(y,y')].
    mmd = spin_spin_kernels.mean() - 2 * spin_sample_kernels.mean() + sample_sample_kernels.mean()

    return mmd


def nll_loss(
spins: torch.Tensor,
grbm: GraphRestrictedBoltzmannMachine,
Expand Down
31 changes: 19 additions & 12 deletions src/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,14 @@
import plotly.express as px
import torch
import yaml

from dwave.plugins.torch.models import (
DiscreteVariationalAutoencoder,
GraphRestrictedBoltzmannMachine,
)
from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss
from dwave.plugins.torch.nn.modules.kernels import GaussianKernel

from einops import rearrange
from plotly import graph_objects as go
from torch.utils.data import DataLoader
Expand All @@ -36,7 +40,7 @@

from .decoder import Decoder
from .encoder import Encoder
from .losses import RadialBasisFunction, mmd_loss, nll_loss
from .losses import nll_loss
from .utils.common import get_latent_to_discrete, get_sampler_and_sampler_kwargs
from .utils.persistent_qpu_sampler import PersistentQPUSampleHelper

Expand Down Expand Up @@ -266,7 +270,7 @@ def train_init(
self._tpar["opt_step"] = 0

# use for self.LOSS_FUNCTION == "mmd":
self._tpar["kernel"] = RadialBasisFunction(num_features=7).to(self._device)
self._tpar["kernel"] = GaussianKernel(n_kernels=7).to(self._device)

self._tpar["sample_set"] = None

Expand Down Expand Up @@ -301,16 +305,19 @@ def step(self, batch: tuple[torch.Tensor, torch.Tensor], epoch: int) -> torch.Te
)
self.losses["mse_losses"].append(mse_loss.item())

_mmd_loss = mmd_loss(
spins=spins,
kernel=self._tpar["kernel"],
grbm=self._grbm,
sampler=self.sampler,
sampler_kwargs=self.sampler_kwargs,
linear_range=self.linear_range,
quadratic_range=self.quadratic_range,
prefactor=self.PREFACTOR,
)
with torch.no_grad():
samples = self._grbm.sample( # type: ignore
sampler=self.sampler,
prefactor=self.PREFACTOR,
linear_range=self.linear_range,
quadratic_range=self.quadratic_range,
device=spins.device,
sample_params=self.sampler_kwargs,
)

spins = spins.reshape(-1, spins.shape[-1])

_mmd_loss = maximum_mean_discrepancy_loss(x=spins, y=samples, kernel=self._tpar["kernel"])

dvae_loss = mse_loss + _mmd_loss

Expand Down