diff --git a/requirements.txt b/requirements.txt index 9fc61fd..cd4504d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ dash[diskcache]~=3.2 dash-bootstrap-components~=2.0 dwave-ocean-sdk~=9.0 -dwave-pytorch-plugin~=0.2 +dwave-pytorch-plugin~=0.3 einops~=0.8 matplotlib~=3.10 pyyaml~=6.0 diff --git a/src/losses.py b/src/losses.py index 347e059..2e417ba 100644 --- a/src/losses.py +++ b/src/losses.py @@ -26,85 +26,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional import torch from dimod import Sampler, SampleSet if TYPE_CHECKING: from dwave.plugins.torch.models import GraphRestrictedBoltzmannMachine - from .utils.persistent_qpu_sampler import PersistentQPUSampleHelper - -class RadialBasisFunction(torch.nn.Module): - """Radial basis function with multiple bandwidth parameters.""" - - def __init__( - self, - num_features: int, - mul_factor: Union[int, float] = 2.0, - bandwidth: Optional[float] = None, - ): - super().__init__() - bandwidth_multipliers = mul_factor ** (torch.arange(num_features) - num_features // 2) - self.register_buffer("bandwidth_multipliers", bandwidth_multipliers) - self.bandwidth = bandwidth - - def get_bandwidth( - self, l2_distance_matrix: Optional[torch.Tensor] = None - ) -> Union[torch.Tensor, float]: - """A heuristic method for determining an appropriate bandwidth for the radial basis function.""" - if self.bandwidth is None: - assert l2_distance_matrix is not None - - num_samples = l2_distance_matrix.shape[0] - - return l2_distance_matrix.sum() / (num_samples * (num_samples - 1)) - - return self.bandwidth - - def forward(self, x: torch.Tensor) -> torch.Tensor: - distance_matrix = torch.cdist(x, x, p=2) - bandwidth = self.get_bandwidth(distance_matrix.detach()) * self.bandwidth_multipliers - - return torch.exp(-distance_matrix.unsqueeze(0) / bandwidth.reshape(-1, 1, 1)).sum(dim=0) - - -def mmd_loss( - spins: torch.Tensor, - kernel: RadialBasisFunction, - grbm: GraphRestrictedBoltzmannMachine, - sampler: Sampler, - sampler_kwargs: dict, - linear_range: tuple[float, float], - quadratic_range: tuple[float, float], - prefactor: float, -) -> float: - """Computes an unbiased estimate of the maximum mean discrepancy metric.""" - with torch.no_grad(): - samples = grbm.sample( - sampler, - prefactor=prefactor, - device=spins.device, - linear_range=linear_range, - quadratic_range=quadratic_range, - sample_params=sampler_kwargs, - ) - - spins = spins.reshape(-1, spins.shape[-1]) - - kernel_matrix = kernel(torch.vstack((spins, samples))) - num_spin_strings = spins.shape[0] - spin_spin_kernels = kernel_matrix[:num_spin_strings, :num_spin_strings] - sample_sample_kernels = kernel_matrix[num_spin_strings:, num_spin_strings:] - spin_sample_kernels = kernel_matrix[:num_spin_strings, num_spin_strings:] - - mmd = spin_spin_kernels.mean() - 2 * spin_sample_kernels.mean() + sample_sample_kernels.mean() - - return mmd - - def nll_loss( spins: torch.Tensor, grbm: GraphRestrictedBoltzmannMachine, diff --git a/src/model_wrapper.py b/src/model_wrapper.py index 5da1a5d..19596a1 100755 --- a/src/model_wrapper.py +++ b/src/model_wrapper.py @@ -21,10 +21,14 @@ import plotly.express as px import torch import yaml + from dwave.plugins.torch.models import ( DiscreteVariationalAutoencoder, GraphRestrictedBoltzmannMachine, ) +from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss +from dwave.plugins.torch.nn.modules.kernels import GaussianKernel + from einops import rearrange from plotly import graph_objects as go from torch.utils.data import DataLoader @@ -36,7 +40,7 @@ from .decoder import Decoder from .encoder import Encoder -from .losses import RadialBasisFunction, mmd_loss, nll_loss +from .losses import nll_loss from .utils.common import get_latent_to_discrete, get_sampler_and_sampler_kwargs from .utils.persistent_qpu_sampler import PersistentQPUSampleHelper @@ -266,7 +270,7 @@ def train_init( self._tpar["opt_step"] = 0 # use for self.LOSS_FUNCTION == "mmd": - self._tpar["kernel"] = RadialBasisFunction(num_features=7).to(self._device) + self._tpar["kernel"] = GaussianKernel(n_kernels=7).to(self._device) self._tpar["sample_set"] = None @@ -301,16 +305,19 @@ def step(self, batch: tuple[torch.Tensor, torch.Tensor], epoch: int) -> torch.Te ) self.losses["mse_losses"].append(mse_loss.item()) - _mmd_loss = mmd_loss( - spins=spins, - kernel=self._tpar["kernel"], - grbm=self._grbm, - sampler=self.sampler, - sampler_kwargs=self.sampler_kwargs, - linear_range=self.linear_range, - quadratic_range=self.quadratic_range, - prefactor=self.PREFACTOR, - ) + with torch.no_grad(): + samples = self._grbm.sample( # type: ignore + sampler=self.sampler, + prefactor=self.PREFACTOR, + linear_range=self.linear_range, + quadratic_range=self.quadratic_range, + device=spins.device, + sample_params=self.sampler_kwargs, + ) + + spins = spins.reshape(-1, spins.shape[-1]) + + _mmd_loss = maximum_mean_discrepancy_loss(x=spins, y=samples, kernel=self._tpar["kernel"]) dvae_loss = mse_loss + _mmd_loss