-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathembedding.py
More file actions
28 lines (24 loc) · 950 Bytes
/
embedding.py
File metadata and controls
28 lines (24 loc) · 950 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
from torch import nn
class ConditionalEmbedding(nn.Module):
    """Embed integer class labels into a continuous conditioning vector.

    Each label id is looked up in a learned embedding table of width
    ``d_model`` and then projected through a small MLP
    (Linear -> SiLU -> Linear) to the output width ``dim``.

    Args:
        num_labels (int): Number of unique labels (embedding table size).
        d_model (int): Dimensionality of the label embedding table. Must be even.
        dim (int): Output dimensionality of the conditional embedding.

    Attributes:
        condEmbedding (nn.Sequential): Embedding table followed by the
            projection MLP.

    Raises:
        ValueError: If ``d_model`` is odd.
    """

    def __init__(self, num_labels: int, d_model: int, dim: int):
        super().__init__()
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would silently skip this validation.
        if d_model % 2 != 0:
            raise ValueError(f"d_model must be even, got {d_model}")
        # NOTE(review): padding_idx=0 pins label 0 to an all-zero embedding
        # that receives no gradient — presumably label 0 is reserved as a
        # null/unconditional class; confirm against the caller.
        self.condEmbedding = nn.Sequential(
            nn.Embedding(num_embeddings=num_labels, embedding_dim=d_model, padding_idx=0),
            nn.Linear(d_model, dim),
            nn.SiLU(),
            nn.Linear(dim, dim),
        )

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """Return the conditioning embedding for label tensor ``t``.

        Args:
            t (torch.Tensor): Long tensor of label indices in
                ``[0, num_labels)``, of any shape ``(...)``.

        Returns:
            torch.Tensor: Float tensor of shape ``(..., dim)``.
        """
        return self.condEmbedding(t)