diff --git a/configs/train/tokenizer.yaml b/configs/train/tokenizer.yaml index 1b27685..4fb87fc 100644 --- a/configs/train/tokenizer.yaml +++ b/configs/train/tokenizer.yaml @@ -12,7 +12,7 @@ experiment: model: vq_model: - finetune_decoder: True + finetune_decoder: False codebook_size: 512 token_size: 32 use_l2_norm: True @@ -23,6 +23,7 @@ model: vit_enc_patch_size: 16 vit_dec_patch_size: 16 num_latent_tokens: 128 + predict_pixels: False losses: discriminator_start: 2_000 diff --git a/models/__init__.py b/models/__init__.py index d2d877f..4df3f91 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,6 +1,6 @@ # models/__init__.py from .base_model import BaseModel -from .blocks import Attention, ResidualAttention +from .blocks import Attention, ResidualAttention, Residual, ResidualStack from .decoder import Decoder from .ema import EMAModel from .encoder import Encoder @@ -16,6 +16,8 @@ "Decoder", "VectorQuantizer", "ResidualAttention", + "ResidualStack", + "Residual", "Attention", "EMAModel", "ReconstructionLoss", diff --git a/models/blocks.py b/models/blocks.py index 5ed57fc..edd93fc 100644 --- a/models/blocks.py +++ b/models/blocks.py @@ -5,9 +5,92 @@ from collections import OrderedDict import torch from torch import nn +import torch.nn.functional as F import einops +class Residual(nn.Module): + """ + Residual Convolutional Layer + """ + + def __init__(self, in_channels, num_hiddens, num_residual_hiddens): + """ + initialize residual CNN layer + + :param in_channels number: Number of input channels + :param num_hiddens number: Number of hidden channels + :param num_residual_hiddens number: Number of residual hiddens + """ + super(Residual, self).__init__() + self._block = nn.Sequential( + nn.ReLU(True), + # C: 3 -> residual hidden + nn.Conv2d( + in_channels=in_channels, + out_channels=num_residual_hiddens, + kernel_size=3, + stride=1, + padding=1, + bias=False, + ), + nn.ReLU(True), + # C: resiual hidden -> out_hidden + nn.Conv2d( + in_channels=num_residual_hiddens, + out_channels=num_hiddens, + kernel_size=1, + stride=1, + bias=False, + ), + ) + + def forward(self, x): + """ + Residual layer + + :param x numpy.ndarray: Input image + """ + return x + self._block(x) # residual output + + +class ResidualStack(nn.Module): + """ + Residual Convolution Stack + """ + + def __init__( + self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens + ): + """ + initialize residual stack + + :param in_channels number: Number of input channels + :param num_hiddens number: Number of hidden channels + :param num_residual_layers number: Number of residual layers in stack + :param num_residual_hiddens number: Number of hidden residual channels + """ + super(ResidualStack, self).__init__() + self._num_residual_layers = num_residual_layers + self._layers = nn.ModuleList( + [ + Residual(in_channels, num_hiddens, num_residual_hiddens) + for _ in range(self._num_residual_layers) + ] + ) + + def forward(self, x): + """ + Apply residual stack + + :param x numpy.ndarray: Input image + """ + # Apply all residual layers in stack + for i in range(self._num_residual_layers): + x = self._layers[i](x) + return F.relu(x) + + class ResidualAttention(nn.Module): """ Residual Attention Block diff --git a/models/cnn_vqvae.py b/models/cnn_vqvae.py new file mode 100644 index 0000000..7dd3b63 --- /dev/null +++ b/models/cnn_vqvae.py @@ -0,0 +1,264 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .blocks import ResidualStack + + +class VectorQuantizeEMA(nn.Module): + """ + Exponential Moving Average (EMA) vector quantization for VQ-VAE model + """ + + def __init__( + self, n_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5 + ): + """ + initialize VQ class + + :param n_embeddings number: number of discrete embeddings + :param embedding_dim number: dimension of embeddings + :param commitment_cost number: commitment cost weight + :param decay number: decay rate for EMA + :param epsilon number: epsilon value for EMA + """ + super(VectorQuantizeEMA, self).__init__() + + self._embedding_dim = embedding_dim # Dimension of an embedding vector, D + self._n_embeddings = n_embeddings # Number of categories in distribution, K + + # Parameters + self._embedding = nn.Embedding( + self._n_embeddings, self._embedding_dim + ) # Embedding table for categorical distribution + self._embedding.weight.data.normal_() # Randomly initialize embeddings + self.register_buffer( + "_ema_cluster_size", torch.zeros(n_embeddings) + ) # Clusters for EMA + self._ema_w = nn.Parameter( + torch.Tensor(n_embeddings, self._embedding_dim) + ) # EMA weights + self._ema_w.data.normal_() + + # Loss / Training Parameters + self._commitment_cost = commitment_cost + self._decay = decay + self._epsilon = epsilon + + def forward(self, z_e): + """ + Quantize embeddings + + :param z_e numpy.ndarray: Embeddings from encoder to quantize + """ + # reshape from BCHW -> BHWC + z_e = z_e.permute(0, 2, 3, 1).contiguous() + shape = z_e.shape + + # Flatten input embeddings + flat_z_e = z_e.view(-1, self._embedding_dim) + + # Claculate Distances + # ||z_e||^2 + ||e||^2 - 2 * z_q + distances = ( + torch.sum(flat_z_e**2, dim=1, keepdim=True) + + torch.sum(self._embedding.weight**2, dim=1) + - 2 * torch.matmul(flat_z_e, self._embedding.weight.t()) + ) + + # Encoding + encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) + encodings = torch.zeros( + encoding_indices.shape[0], self._n_embeddings, device=z_e.device + ) + # Convert to shape of embeddings + encodings.scatter_(1, encoding_indices, 1) + + # Quantize and Unflatten + z_q = torch.matmul(encodings, self._embedding.weight).view(shape) + + # Update weights with EMA + if self.training: + self._ema_cluster_size = self._ema_cluster_size * self._decay + ( + 1 - self._decay + ) * torch.sum(encodings, 0) + + # Laplace smoothing + n = torch.sum(self._ema_cluster_size.data) + self._ema_cluster_size = ( + (self._ema_cluster_size + self._epsilon) + / (n + self._n_embeddings * self._epsilon) + * n + ) + + dw = torch.matmul(encodings.t(), flat_z_e) + self._ema_w = nn.Parameter( + self._ema_w * self._decay + (1 - self._decay) * dw + ) + + self._embedding.weight = nn.Parameter( + self._ema_w / self._ema_cluster_size.unsqueeze(1) + ) + + # Loss + e_latent_loss = F.mse_loss( + z_q.detach(), z_e + ) # distance from encoder output and quantized embeddings + loss = self._commitment_cost * e_latent_loss + + # Straight Through Loss + z_q = z_e + (z_q - z_e).detach() + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * + torch.log(avg_probs + 1e-10))) + + # Convert shape back to BCHW + return loss, z_q.permute(0, 3, 1, 2).contiguous(), perplexity, encodings + + +class Encoder(nn.Module): + """ + Convolutional Encoder producing a 16x16 grid from 256x256 input + """ + + def __init__( + self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens + ): + """ + Initialize encoder network + :param in_channels: Number of input channels + :param num_hiddens: Number of hidden channels + :param num_residual_layers: Number of layers in residual stack + :param num_residual_hiddens: Number of channels in residual hidden layer + """ + super(Encoder, self).__init__() + + # Modify convolution layers to achieve 16x16 output from 256x256 input + self._conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=num_hiddens // 2, + kernel_size=4, + stride=2, + padding=1, # 256 -> 128 + ) + self._conv2 = nn.Conv2d( + in_channels=num_hiddens // 2, + out_channels=num_hiddens, + kernel_size=4, + stride=2, + padding=1, # 128 -> 64 + ) + self._conv3 = nn.Conv2d( + in_channels=num_hiddens, + out_channels=num_hiddens, + kernel_size=4, + stride=2, + padding=1, # 64 -> 32 + ) + self._conv4 = nn.Conv2d( + in_channels=num_hiddens, + out_channels=num_hiddens, + kernel_size=4, + stride=2, + padding=1, # 32 -> 16 + ) + + self._residual_stack = ResidualStack( + in_channels=num_hiddens, + num_hiddens=num_hiddens, + num_residual_layers=num_residual_layers, + num_residual_hiddens=num_residual_hiddens, + ) + + def forward(self, inputs): + """ + Encode image + :param inputs: images to encode (256x256) + :return: latent representation (16x16) + """ + x = self._conv1(inputs) + x = F.relu(x) + x = self._conv2(x) + x = F.relu(x) + x = self._conv3(x) + x = F.relu(x) + x = self._conv4(x) + return self._residual_stack(x) + + +class Decoder(nn.Module): + """ + Convolutional Decoder reconstructing 256x256 from 16x16 input + """ + + def __init__( + self, in_channels, num_hiddens, num_residual_layers, num_residual_hiddens + ): + """ + Initialize decoder network + :param in_channels: Number of input channels + :param num_hiddens: Number of hidden channels + :param num_residual_layers: Number of residual layers in stack + :param num_residual_hiddens: Number of channels in residual + """ + super(Decoder, self).__init__() + + self._conv1 = nn.Conv2d( + in_channels=in_channels, + out_channels=num_hiddens, + kernel_size=3, + stride=1, + padding=1, + ) + self._residual_stack = ResidualStack( + in_channels=num_hiddens, + num_hiddens=num_hiddens, + num_residual_layers=num_residual_layers, + num_residual_hiddens=num_residual_hiddens, + ) + + # Add upsampling convolution transpose layers to match Encoder + self._conv_trans_1 = nn.ConvTranspose2d( + in_channels=num_hiddens, + out_channels=num_hiddens // 2, + kernel_size=4, + stride=2, + padding=1, # 16 -> 32 + ) + self._conv_trans_2 = nn.ConvTranspose2d( + in_channels=num_hiddens // 2, + out_channels=num_hiddens // 4, + kernel_size=4, + stride=2, + padding=1, # 32 -> 64 + ) + self._conv_trans_3 = nn.ConvTranspose2d( + in_channels=num_hiddens // 4, + out_channels=num_hiddens // 8, + kernel_size=4, + stride=2, + padding=1, # 64 -> 128 + ) + self._conv_trans_4 = nn.ConvTranspose2d( + in_channels=num_hiddens // 8, + out_channels=3, + kernel_size=4, + stride=2, + padding=1, # 128 -> 256 + ) + + def forward(self, inputs): + """ + Decode latent embeddings + :param inputs: latent embeddings (16x16) + :return: reconstructed image (256x256) + """ + x = self._conv1(inputs) + x = self._residual_stack(x) + x = self._conv_trans_1(x) + x = F.relu(x) + x = self._conv_trans_2(x) + x = F.relu(x) + x = self._conv_trans_3(x) + x = F.relu(x) + return self._conv_trans_4(x) diff --git a/models/decoder.py b/models/decoder.py index 52fe580..66080dd 100644 --- a/models/decoder.py +++ b/models/decoder.py @@ -16,7 +16,7 @@ def __init__(self, grid_size): self.grid_size = grid_size def forward(self, x): - return x[:, 0 : self.grid_size**2] + return x[:, 0: self.grid_size**2] class Decoder(nn.Module): @@ -29,14 +29,16 @@ def __init__(self, config): self.model_size = config.model.vq_model.vit_dec_model_size self.num_latent_tokens = config.model.vq_model.num_latent_tokens self.token_size = config.model.vq_model.token_size + self.predict_pixels = config.model.vq_model.get( + "predict_pixels", False) self.width = { - "small": 128, + "small": 512, "base": 768, "large": 1024, }[self.model_size] self.num_layers = { - "small": 1, + "small": 4, "base": 12, "large": 24, }[self.model_size] @@ -65,16 +67,25 @@ def __init__(self, config): ) self.ln_post = nn.LayerNorm(self.width) # post attention layer norm # FFN to convert mask tokens to image patches - self.ffn = nn.Sequential( - nn.Conv2d(self.width, 3 * self.patch_size**2, 1, padding=0, bias=True), - Rearrange( - "B (P1 P2 C) H W -> B C (H P1) (W P2)", - P1=self.patch_size, - P2=self.patch_size, - ), - ) - # conv layer on pixel output - self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True) + if self.predict_pixels: + self.ffn = nn.Sequential( + nn.Conv2d(self.width, 3 * self.patch_size ** + 2, 1, padding=0, bias=True), + Rearrange( + "B (P1 P2 C) H W -> B C (H P1) (W P2)", + P1=self.patch_size, + P2=self.patch_size, + ), + ) + # conv layer on pixel output + self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True) + else: + self.ffn = nn.Sequential( + nn.Conv2d(self.width, 2 * self.width, 1, padding=0, bias=True), + nn.Tanh(), + nn.Conv2d(2 * self.width, 64, 1, padding=0, bias=True) + ) + self.conv_out = nn.Identity() self.model = nn.Sequential( self.ln_pre, @@ -83,7 +94,8 @@ def __init__(self, config): Rearrange("L B C -> B L C"), RemoveLatentTokens(grid_size=self.grid_size), self.ln_post, - Rearrange("B (H W) C -> B C H W", H=self.grid_size, W=self.grid_size), + Rearrange("B (H W) C -> B C H W", + H=self.grid_size, W=self.grid_size), self.ffn, self.conv_out ) @@ -95,9 +107,11 @@ def forward(self, z_q): x = einops.rearrange(z_q, "B C L -> B L C") x = self.decoder_embed(x) # embed tokens in model dim - mask_tokens = self.mask_token.repeat(B, self.grid_size**2, 1).to(x.dtype) + mask_tokens = self.mask_token.repeat( + B, self.grid_size**2, 1).to(x.dtype) # Add positional embeddings - mask_tokens = mask_tokens + self.positional_embedding.to(mask_tokens.dtype) + mask_tokens = mask_tokens + \ + self.positional_embedding.to(mask_tokens.dtype) x = x + self.latent_token_positional_embedding[:L] x = torch.cat([mask_tokens, x], dim=1) diff --git a/models/encoder.py b/models/encoder.py index ae7c9e0..e6b4bac 100644 --- a/models/encoder.py +++ b/models/encoder.py @@ -16,7 +16,7 @@ def __init__(self, grid_size): self.grid_size = grid_size def forward(self, x): - return x[:, self.grid_size**2 :] + return x[:, self.grid_size**2:] class Encoder(nn.Module): @@ -40,12 +40,12 @@ def __init__(self, config): self.token_size = config.model.vq_model.token_size self.width = { - "small": 128, + "small": 512, "base": 768, "large": 1024, }[self.model_size] self.num_layers = { - "small": 1, + "small": 4, "base": 12, "large": 24, }[self.model_size] @@ -84,7 +84,8 @@ def __init__(self, config): # post trasnformer layer norm self.ln_post = nn.LayerNorm(self.width) # project model dim to token dim - self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True) + self.conv_out = nn.Conv2d( + self.width, self.token_size, kernel_size=1, bias=True) # encoder model self.model = nn.Sequential( diff --git a/models/loss.py b/models/loss.py index 346dd7a..a22d3ec 100644 --- a/models/loss.py +++ b/models/loss.py @@ -100,7 +100,46 @@ def forward(self, input: torch.Tensor, target: torch.Tensor): return loss -class ReconstructionLoss(nn.Module): +class ReconstructionLoss_Stage1(nn.Module): + def __init__(self, config): + super().__init__() + loss_config = config.losses + self.quantizer_weight = loss_config.quantizer_weight + self.target_codebook_size = 128 # size of proxy codes codebook + + def forward( + self, + target_codes: torch.Tensor, + reconstructions: torch.Tensor, + quantizer_loss: torch.Tensor, + ) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]: + reconstructions = reconstructions.contiguous() + reconstructions = reconstructions.contiguous() + loss_fct = nn.CrossEntropyLoss(reduction="mean") + batch_size = reconstructions.shape[0] + reconstruction_loss = loss_fct( + reconstructions.view(batch_size, self.target_codebook_size, -1), + target_codes.view(batch_size, -1), + ) + total_loss = ( + reconstruction_loss + + self.quantizer_weight * quantizer_loss["quantizer_loss"] + ) + + loss_dict = dict( + total_loss=total_loss.clone().detach(), + reconstruction_loss=reconstruction_loss.detach(), + quantizer_loss=( + self.quantizer_weight * quantizer_loss["quantizer_loss"] + ).detach(), + commitment_loss=quantizer_loss["commitment_loss"].detach(), + codebook_loss=quantizer_loss["codebook_loss"].detach(), + ) + + return total_loss, loss_dict + + +class ReconstructionLoss_Stage2(nn.Module): """ Loss for reconstruction network """ diff --git a/models/tokenizer.py b/models/tokenizer.py index e66bddb..062427e 100644 --- a/models/tokenizer.py +++ b/models/tokenizer.py @@ -7,11 +7,108 @@ from torch import nn from .base_model import BaseModel +from .cnn_vqvae import Decoder as PixelDecoder +from .cnn_vqvae import Encoder as PixelEncoder +from .cnn_vqvae import VectorQuantizeEMA from .decoder import Decoder from .encoder import Encoder from .quantizer import VectorQuantizer +class VQVAE(nn.Module): + """ + VQ-VAE Network + Contains: Encoder, Quantizer, Decoder + """ + + def __init__(self, pretrained_weight): + """ + Initialize VQ-VAE encoder decoder network + + :param num_hiddens number: Number of hidden layers + :param num_residual_layers number: Number of residual stacks + :param num_residual_hiddens number: Number of channels in hidden layer + :param num_embeddings number: Number of discrete embeddings + :param embedding_dim number: Dimension of discrete embeddings + :param commitment_cost number: Weight for commitment const in loss + :param decay number: Decay parameter in EMA + """ + num_hiddens = 128 + num_residual_layers = 2 + num_residual_hiddens = 32 + num_embeddings = 128 + embedding_dim = 8 + commitment_cost = 0.25 + decay = 0.99 + + super(VQVAE, self).__init__() + + self.load_state_dict( + torch.load(pretrained_weight, map_location=torch.device("cpu")), strict=True + ) + + self._encoder = PixelEncoder( + 3, num_hiddens, num_residual_layers, num_residual_hiddens + ) + self._pre_vq_conv = nn.Conv2d( + in_channels=num_hiddens, out_channels=embedding_dim, kernel_size=1, stride=1 + ) + + self._vq = VectorQuantizeEMA( + num_embeddings, embedding_dim, commitment_cost, decay + ) + self._decoder = PixelDecoder( + embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens + ) + + self.eval() + for param in self.parameters(): + param.requires_grad(False) + + def set_embeddings(self, new_embeddings): + """ + Set discrete embeddings codebook params + + :param new_embeddings numpy.ndarray: Embedding codebook + """ + with torch.no_grad(): + self._vq._embedding.weight.copy_(new_embeddings) + + def encode(self, x): + """ + Encode image + + :param x numpy.ndarray: Input image + """ + z = self._encoder(x) + z_e = self._pre_vq_conv(z) + return z_e + + def pretrain(self, x): + """ + Bypass vector quantize step for pretraining + + :param x numpy.ndarray: Input image + """ + z = self._encoder(x) + z = self._pre_vq_conv(z) + x_recon = self._decoder(z) + return x_recon + + def forward(self, x): + """ + Encode and reconstruct image + + :param x numpy.ndarray: Input image + """ + z = self._encoder(x) # encode image to latent + z = self._pre_vq_conv(z) + # quantize encoding to dicrete space + loss, z_q, perplexity, _ = self._vq(z) + x_recon = self._decoder(z_q) # reconstruction of input from decoder + return loss, x_recon, perplexity + + class Tokenizer(BaseModel): """ 1D Image Tokenizer @@ -26,6 +123,8 @@ def __init__(self, config): super().__init__() self.config = config + self.finetune_decoder = config.model.vq_model.get("finetune_decoder", True) + self.encoder = Encoder(config) self.decoder = Decoder(config) @@ -43,6 +142,28 @@ def __init__(self, config): commitment_cost=config.model.vq_model.commitment_cost, ) + if self.finetune_decoder: + # Freeze encoder and quantizer gradients + self.latent_tokens.requires_grad(False) + self.encoder.eval() + self.encoder.requires_grad(False) + self.quantizer.eval() + self.quantizer.requires_grad(False) + + self.pixel_quantizer = VectorQuantizeEMA( + n_embeddings=64, + embedding_dim=8, + commitment_cost=0.25, + decay=0.99, + epsilon=0.01, + ) + self.pixel_decoder = PixelDecoder( + in_channels=3, + num_hiddens=128, + num_residual_layers=2, + num_residual_hiddens=32, + ) + def _init_weights(self, module): """Initialize the weights. :param: @@ -72,8 +193,18 @@ def encode(self, x): :param x torch.Tensor: pixel values """ - z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens) - z_q, result_dict = self.quantizer(z) + if self.finetune_decoder: + with torch.no_grad(): + self.encoder.eval() + self.quantizer.eval() + z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens) + z_q, result_dict = self.quantizer(z) + result_dict["quantizer_loss"] *= 0 + result_dict["commitment_loss"] *= 0 + result_dict["codebook_loss"] *= 0 + else: + z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens) + z_q, result_dict = self.quantizer(z) return z_q, result_dict @@ -83,7 +214,15 @@ def decode(self, z_q): :param z_q torch.Tensor: latent embeddings """ - return self.decoder(z_q) + decoded = self.decoder(z_q) + if self.finetune_decoder: + quantized_states = torch.einsum( + "N C H W, C D -> N D H W", + decoded.softmax(1), + self.pixel_quantizer._embedding.weight, + ) + decoded = self.pixel_decoder(quantized_states) + return decoded def decode_tokens(self, tokens): """