diff --git a/examples/configs/sent140_config.json b/examples/configs/sent140_config.json index 68fc586c..35f07760 100644 --- a/examples/configs/sent140_config.json +++ b/examples/configs/sent140_config.json @@ -15,16 +15,16 @@ } }, "client": { - "epochs": 1, + "epochs": 5, "optimizer": { "_base_": "base_optimizer_sgd", - "lr": 1, + "lr": 0.01, "momentum": 0 } }, "users_per_round": 10, "epochs": 1, - "train_metrics_reported_per_epoch": 10, + "train_metrics_reported_per_epoch": 100, "always_keep_trained_model": false, "eval_epoch_frequency": 1, "do_eval": true, @@ -37,7 +37,7 @@ "model": { "num_classes": 2, "n_hidden": 100, - "dropout_rate": 0.1 + "dropout_rate": 0 } } } diff --git a/examples/sent140_example.py b/examples/sent140_example.py index 82be7678..3fee6904 100644 --- a/examples/sent140_example.py +++ b/examples/sent140_example.py @@ -18,12 +18,8 @@ FedBuff + SGDM python3 sent140_tutorial.py --config-file configs/sent140_fedbuff_config.json """ -import itertools import json import re -import string -import unicodedata -from typing import List import flsim.configs # noqa import hydra # @manual @@ -42,53 +38,43 @@ from torch.utils.data import Dataset -class CharLSTM(nn.Module): +class LSTMModel(nn.Module): def __init__( - self, - num_classes, - n_hidden, - num_embeddings, - embedding_dim, - max_seq_len, - dropout_rate, + self, seq_len, num_classes, embedding_dim, n_hidden, vocab_size, dropout_rate ): - super().__init__() - self.dropout_rate = dropout_rate - self.n_hidden = n_hidden + super(LSTMModel, self).__init__() + self.seq_len = seq_len self.num_classes = num_classes - self.max_seq_len = max_seq_len - self.num_embeddings = num_embeddings + self.n_hidden = n_hidden + self.vocab_size = vocab_size + self.embedding_dim = embedding_dim + self.dropout_rate = dropout_rate - self.embedding = nn.Embedding( - num_embeddings=self.num_embeddings, embedding_dim=embedding_dim - ) - self.lstm = nn.LSTM( - input_size=embedding_dim, - hidden_size=self.n_hidden, - num_layers=2, + self.embedding = nn.Embedding(self.vocab_size + 1, self.embedding_dim) + self.stacked_lstm = nn.LSTM( + self.embedding_dim, + self.n_hidden, + 2, batch_first=True, dropout=self.dropout_rate, ) - self.fc = nn.Linear(self.n_hidden, self.num_classes) + self.fc1 = nn.Linear(self.n_hidden, self.num_classes) self.dropout = nn.Dropout(p=self.dropout_rate) + self.out = nn.Linear(128, self.num_classes) - def forward(self, x): - seq_lens = torch.sum(x != (self.num_embeddings - 1), 1) - 1 - x = self.embedding(x) # [B, S] -> [B, S, E] - out, _ = self.lstm(x) # [B, S, E] -> [B, S, H] - out = out[torch.arange(out.size(0)), seq_lens] - out = self.fc(self.dropout(out)) # [B, S, H] -> # [B, S, C] - return out + def forward(self, features): + seq_lens = torch.sum(features != (self.vocab_size - 1), 1) - 1 + x = self.embedding(features) + outputs, _ = self.stacked_lstm(x) + outputs = outputs[torch.arange(outputs.size(0)), seq_lens] + pred = self.fc1(self.dropout(outputs)) + return pred class Sent140Dataset(Dataset): def __init__(self, data_root, max_seq_len): self.data_root = data_root self.max_seq_len = max_seq_len - self.all_letters = {c: i for i, c in enumerate(string.printable)} - self.num_letters = len(self.all_letters) - self.UNK: int = self.num_letters - with open(data_root, "r+") as f: self.dataset = json.load(f) @@ -96,6 +82,8 @@ def __init__(self, data_root, max_seq_len): self.targets = {} self.num_classes = 2 + self.word2id = self.build_vocab() + self.vocab_size = len(self.word2id) # Populate self.data and self.targets for user_id, user_data in self.dataset["user_data"].items(): @@ -115,25 +103,16 @@ def __getitem__(self, user_id: str): return self.data[user_id], self.targets[user_id] - def unicodeToAscii(self, s): - return "".join( - c - for c in unicodedata.normalize("NFD", s) - if unicodedata.category(c) != "Mn" and c in self.all_letters - ) - - def line_to_indices(self, line: str, max_seq_len: int): - line_list = self.split_line(line) # split phrase in words - line_list = line_list - chars = self.flatten_list([list(word) for word in line_list]) - # padding - indices: List[int] = [ - self.all_letters.get(letter, self.UNK) - for i, letter in enumerate(chars) - if i < max_seq_len - ] - indices = indices + ([self.UNK] * (max_seq_len - len(indices))) - return indices + def build_vocab(self): + word2id = {} + for user_data in self.dataset["user_data"].values(): + lines = [e[4] for e in user_data["x"]] + for line in lines: + line_list = self.split_line(line) + for word in line_list: + if word not in word2id: + word2id[word] = len(word2id) + return word2id def process_x(self, raw_x_batch): x_batch = [e[4] for e in raw_x_batch] @@ -145,18 +124,24 @@ def process_y(self, raw_y_batch): y_batch = [int(e) for e in raw_y_batch] return y_batch - def split_line(self, line): - """split given line/phrase into list of words - Args: - line: string representing phrase to be split + def line_to_indices(self, line, max_words=25): + unk_id = len(self.word2id) + line_list = self.split_line(line) # split phrase in words + indl = [ + self.word2id[w] if w in self.word2id else unk_id + for w in line_list[:max_words] + ] + indl += [unk_id - 1] * (max_words - len(indl)) + return indl - Return: - list of strings, with each string representing a word - """ - return re.findall(r"[\w']+|[.,!?;]", line) + def val_to_vec(self, size, val): + assert 0 <= val < size + vec = [0 for _ in range(size)] + vec[int(val)] = 1 + return vec - def flatten_list(self, nested_list): - return list(itertools.chain.from_iterable(nested_list)) + def split_line(self, line): + return re.findall(r"[\w']+|[.,!?;]", line) def build_data_provider(data_config, drop_last=False): @@ -179,7 +164,7 @@ def build_data_provider(data_config, drop_last=False): ) data_provider = DataProvider(dataloader) - return data_provider, train_dataset.num_letters + return data_provider, train_dataset.vocab_size def main_worker( @@ -189,16 +174,17 @@ def main_worker( use_cuda_if_available=True, distributed_world_size=1, ): - data_provider, num_letters = build_data_provider(data_config) + data_provider, vocab_size = build_data_provider(data_config) - model = CharLSTM( + model = LSTMModel( num_classes=model_config.num_classes, n_hidden=model_config.n_hidden, - num_embeddings=num_letters + 1, - embedding_dim=100, - max_seq_len=data_config.max_seq_len, + vocab_size=vocab_size, + embedding_dim=50, + seq_len=data_config.max_seq_len, dropout_rate=model_config.dropout_rate, ) + cuda_enabled = torch.cuda.is_available() and use_cuda_if_available device = torch.device(f"cuda:{0}" if cuda_enabled else "cpu") global_model = FLModel(model, device)