From 21ad819a9742fdd31674d90b04bf81dc74c64f6a Mon Sep 17 00:00:00 2001 From: "e2b-for-github[bot]" <134465507+e2b-for-github[bot]@users.noreply.github.com> Date: Thu, 6 Jul 2023 00:17:24 +0000 Subject: [PATCH 1/2] Initial commit From 25cfc3aaee145fc29b9f84d1ade44a65cb9d0fb8 Mon Sep 17 00:00:00 2001 From: "e2b-for-github[bot]" <337977+e2b-for-github[bot]@users.noreply.github.com> Date: Thu, 6 Jul 2023 00:20:28 +0000 Subject: [PATCH 2/2] Add code based on the PR instructions --- .gitignore | 130 ------- README.md | 6 - __init__.py | 0 bart_rl.py | 201 ----------- clf_sst2.py | 71 ---- data.py | 158 --------- deprecated/train.py | 198 ----------- deprecated/train_1014.py | 208 ----------- notebooks/Diversity exploration.ipynb | 492 -------------------------- package.json | 47 +++ public/index.html | 12 + requirements.txt | 37 -- reward.py | 118 ------ shared_dependencies.md | 23 ++ src/App.tsx | 22 ++ src/components/Login.tsx | 44 +++ src/components/Logout.tsx | 27 ++ src/components/ProtectedRoute.tsx | 24 ++ src/components/SignUp.tsx | 48 +++ src/index.tsx | 14 + src/services/auth.ts | 41 +++ src/styles/global.css | 26 ++ src/styles/login.css | 42 +++ src/styles/logout.css | 21 ++ src/styles/signup.css | 42 +++ src/types/user.ts | 6 + src/utils/firebase.ts | 15 + train_bart_v2.py | 151 -------- tsconfig.json | 25 ++ utils.py | 54 --- 30 files changed, 479 insertions(+), 1824 deletions(-) delete mode 100644 .gitignore delete mode 100644 README.md delete mode 100644 __init__.py delete mode 100644 bart_rl.py delete mode 100644 clf_sst2.py delete mode 100644 data.py delete mode 100644 deprecated/train.py delete mode 100644 deprecated/train_1014.py delete mode 100644 notebooks/Diversity exploration.ipynb create mode 100755 package.json create mode 100755 public/index.html delete mode 100644 requirements.txt delete mode 100644 reward.py create mode 100755 shared_dependencies.md create mode 100755 src/App.tsx create mode 100755 src/components/Login.tsx create mode 100755 src/components/Logout.tsx create mode 100755 src/components/ProtectedRoute.tsx create mode 100755 src/components/SignUp.tsx create mode 100755 src/index.tsx create mode 100755 src/services/auth.ts create mode 100755 src/styles/global.css create mode 100755 src/styles/login.css create mode 100755 src/styles/logout.css create mode 100755 src/styles/signup.css create mode 100755 src/types/user.ts create mode 100755 src/utils/firebase.ts delete mode 100644 train_bart_v2.py create mode 100755 tsconfig.json delete mode 100644 utils.py diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 20ad5d8..0000000 --- a/.gitignore +++ /dev/null @@ -1,130 +0,0 @@ -##################### PYTHON GIT IGNORES ################################# -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -pip-wheel-metadata/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -.python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ diff --git a/README.md b/README.md deleted file mode 100644 index c12f11c..0000000 --- a/README.md +++ /dev/null @@ -1,6 +0,0 @@ -# DiversityDataAugmentation - -## Data -- Download IMDB data [here](https://www.kaggle.com/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews) -- `python -m spacy download en` following torchtext tokenization [tutorial](https://pytorch.org/tutorials/beginner/torchtext_translation_tutorial.html) -- Download SST-2 diff --git a/__init__.py b/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/bart_rl.py b/bart_rl.py deleted file mode 100644 index 8224794..0000000 --- a/bart_rl.py +++ /dev/null @@ -1,201 +0,0 @@ -from transformers import ( - MaxLengthCriteria, - TemperatureLogitsWarper, - TopKLogitsWarper, - MinLengthLogitsProcessor, - LogitsProcessorList -) -from transformers import ( - BartTokenizerFast, BartForConditionalGeneration -) -import torch -from utils import DEV, LOG_EPS, MODEL_KEY - - -def load_bart_model(layers=None): - ''' - Load pretrained BartForConditionalGeneration - ''' - config = dict() - if layers: - assert type(layers) == int - config = dict(encoder_layers=layers, decoder_layers=layers) - model = BartForConditionalGeneration.from_pretrained(MODEL_KEY, **config) - return model - -def load_bart_tokenizer(): - ''' - Load pretrained bart tokenizer - ''' - return BartTokenizerFast.from_pretrained(MODEL_KEY) - -class BartReinforce(): - ''' - BART for RL - specifically policy gradient with REINFORCE - ''' - def __init__(self, model, device): - self.model = model - self.model = self.model.to(device) - self.device = device - # Contains actions from latest batch. List of token_id LongTensors - self.actions = list() - # Contains log_probs of actions from latest batch. FloatTensor(batch_size, num_steps) of log probs - self.log_probs = None - # if MULTI_GPU: - # self.parallel_forward = torch.nn.DataParallel(self.model, device_ids=device_ids).to(dev) - # else: - self.pad_token_id = self.model.config.pad_token_id - self.eos_token_id = self.model.config.eos_token_id - self.bos_token_id = self.model.config.bos_token_id - # This is 2 which equals , the eos_token_id - self.decoder_start_token_id = self.model.config.decoder_start_token_id - self.is_encoder_decoder = self.model.config.is_encoder_decoder - - @property - def encoder(self): - ''' - Getter for encoder - ''' - return self.model.model.encoder - - def freeze_encoder_params(self): - ''' - Freeze encoder params from updates - ''' - for layer in self.encoder.parameters(): - layer.requires_grad= False - - def clear_episode_batch(self): - ''' - Clear actions and log probs from last batch of episodes - ''' - self.actions = list() - self.log_probs = None - - def sample_policy(self, probs): - ''' - Epsilon-greedy sampling from softmax distribution - - Args: - probs :torch.FloatTensor of shape (batch_size, vocab_size): Softmax probs for token - ''' - # epsilon-greedy (note this goes across batch), use uniform probs - if torch.rand(1).item() < self.epsilon: - # filter out 0 probability tokens - num_nonzero = (probs != 0).sum().item() - sample_probs = torch.ones(probs.shape)/num_nonzero - sample_probs[probs == 0] = 0. - # use policy probs - else: - sample_probs = probs - return torch.distributions.Categorical(sample_probs).sample() - - def run_step(self, probs, unfinished_sequences): - ''' - Sample next tokens for batch and store actions and log probabilities - ''' - next_tokens = self.sample_policy(probs).to(self.device) - next_tokens = next_tokens * unfinished_sequences + self.pad_token_id * (1 - unfinished_sequences) - probs = probs + LOG_EPS - selected_log_probs = torch.log(probs.gather(1, next_tokens.unsqueeze(-1))) - self.actions.append(next_tokens.cpu()) - if self.log_probs is None: - self.log_probs = selected_log_probs - else: - self.log_probs = torch.cat((self.log_probs, selected_log_probs), 1) - return next_tokens - - def prepare_inputs_for_decoder(self, input_ids, model_kwargs): - ''' - Run encoder and set up decoder input ids - - Returns: - :torch.LongTensor of shape (batch_size, vocab_size): Decoder input ids - :dict: Updated model_kwargs with encoder_outputs - ''' - # Should be True for BART models. Run encoder and set up decoder inputs - if self.is_encoder_decoder: - encoder_input_ids, attention_mask = input_ids, model_kwargs['attention_mask'] - # Get encoder outputs of type BaseModelOutput - self.encoder.to(self.device) - model_kwargs['encoder_outputs'] = self.encoder(input_ids=encoder_input_ids, attention_mask=attention_mask) - model_kwargs = self.model._prepare_encoder_decoder_kwargs_for_generation(input_ids, model_kwargs) - # Set input_ids as decoder_input_ids - if "decoder_input_ids" in model_kwargs: - input_ids = model_kwargs.pop("decoder_input_ids") - else: - input_ids = self.model._prepare_decoder_input_ids_for_generation( - input_ids, decoder_start_token_id=self.decoder_start_token_id, bos_token_id=self.bos_token_id - ) - if "encoder_outputs" not in model_kwargs: - raise ValueError("Make sure that `model_kwargs` include `encoder_outputs`.") - return input_ids, model_kwargs - - - def generate_episodes(self, batch, min_length=0, max_length=None, temperature=1.0, epsilon=0.001, topk=500, verbose=False): - ''' - Generate episodes over batch of sequences - - Args: - batch :List[torch.Tensor]: Contains the below elements (in order) - input_ids :shape (batch_size, seq_length): Input sequence for generation. - attention_mask :shape (batch_size, seq_length): Attention mask. - labels :shape (batch_size, ): Label for each sequence. - max_length :int: Max output sequence length - temperature :float: Rescale logits before softmax by `logits = logits/temperature`. Higher temperatures t result in softer probability distribution. which goes to uniform as t->infinity. - ''' - # Set epsilon for this batch - self.epsilon = epsilon - model_kwargs = dict() - input_ids, attention_mask, labels = batch - input_ids, attention_mask, labels = input_ids.to(self.device), attention_mask.to(self.device), labels.to(self.device) - model_kwargs['attention_mask'] = attention_mask - input_ids, model_kwargs = self.prepare_inputs_for_decoder(input_ids, model_kwargs) - max_length = max_length if max_length is not None else self.model.config.max_length - # For setting sequence length limit - stopping_criteria = MaxLengthCriteria(max_length) - # Get distribution pre_processing samplers - logits_warper = LogitsProcessorList() - if temperature > 0: - logits_warper.append(TemperatureLogitsWarper(temperature)) - if topk > 0: - logits_warper.append(TopKLogitsWarper(topk)) - if min_length > 0: - logits_warper.append(MinLengthLogitsProcessor(min_length, self.eos_token_id)) - ## Generation ## - # Keep track of which sequences are already finished - ### - # Initially unfinished_sequences = sequence of 1s with length batch size - # and cur_len = tensor of shape (batch_size, 1) containing decoder_start_token_id. - #### - unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) - cur_len = input_ids.shape[-1] - i = 0 - while True: - i += 1 - if verbose: - print("Step", i) - ## Run decoder for one time step ## - # Dictionary with masks, decoder input ids, and encoder outputs - model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs) - # Seq2SeqLMOutput - outputs = self.model(**model_inputs, return_dict=True) - # Logits of shape (batch_size, 1, vocab_size) -> (batch_size, vocab_size). See top k with torch.topk - next_token_scores = logits_warper(input_ids, outputs.logits[:, -1, :]) - probs = torch.nn.functional.softmax(next_token_scores, dim=-1) - next_tokens = self.run_step(probs, unfinished_sequences) - ## Update for next step ## - # append next tokens - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - # update past to past_key_values from outputs, attention mask should be same as from model_inputs - model_kwargs = self.model._update_model_kwargs_for_generation( - outputs, model_kwargs, is_encoder_decoder=self.is_encoder_decoder - ) - # update length - cur_len = cur_len + 1 - # If eos_token was found in one sentence, set sentence to finished - unfinished_sequences = unfinished_sequences.mul((next_tokens != self.eos_token_id).long()) - # stop when each sentence is finished, or if we exceed the maximum length - if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, None): - break - return input_ids \ No newline at end of file diff --git a/clf_sst2.py b/clf_sst2.py deleted file mode 100644 index 25cfd8c..0000000 --- a/clf_sst2.py +++ /dev/null @@ -1,71 +0,0 @@ -from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification -import utils -from data import SSTLoader, TokenizerWrapper -import torch -from torch import nn -import numpy as np -from tqdm import tqdm -from typing import Union, List - -class DistilBertSST(nn.Module): - ''' - Get finetuned checkpoint of distilbert on sst2 - ''' - model_key = 'distilbert-base-uncased-finetuned-sst-2-english' - def __init__(self, device=utils.DEV): - super().__init__() - self.device = device - self.model = DistilBertForSequenceClassification.from_pretrained(self.model_key).to(self.device) - self.tokenizer = TokenizerWrapper( - DistilBertTokenizerFast.from_pretrained(self.model_key), - { - 'return_tensors': 'pt', - 'padding': True, - 'truncation': True - } - ) - - def forward(self, *args, **kwargs): - return self.model(*args, **kwargs) - - @torch.no_grad() - def predict_on_text(self, text: Union[str, List[str]]) -> np.ndarray: - ''' - Get predicted labels from applying model to text or texts. - Let N := number of input texts and C := the number of classes - - Returns: - :np.ndarray of shape (N, ): Predicted labels - :np.ndarray of shape (N, C): Probabilities for each class - ''' - self.model.eval() - encodings = self.tokenizer.encode(text) - outputs = self(encodings['input_ids'].to(self.device), encodings['attention_mask'].to(self.device)) - return ( - np.argmax(utils.convert_to_numpy(outputs.logits), axis=1).flatten(), - nn.functional.softmax(outputs.logits, dim=-1) - ) - -def run_validate_model(): - model = DistilBertSST() - model.to(self.device) - sst2 = SSTLoader(model.tokenizer, batch_size=128, lim=-1) - train_loader, val_loader, test_loader = sst2.get_train_loader(), sst2.get_val_loader(), sst2.get_test_loader() - - print(f"Testing on {len(val_loader)*val_loader.batch_size} inputs") - model.eval() - with tqdm(val_loader, unit="batch") as pbar: - for batch in pbar: - accs = [] - with torch.no_grad(): - batch = [data.to(self.device) for data in batch] - outputs = model(batch[0], batch[1], labels=batch[2]) - accs.append(utils.flat_accuracy(outputs.logits, batch[2])) - pbar.set_description(f"Mean Accuracy so far = {np.mean(accs)}") - print("Final Accuracy = ", np.mean(accs)) - -if __name__ == '__main__': - run_validate_model() - - - diff --git a/data.py b/data.py deleted file mode 100644 index ff236d7..0000000 --- a/data.py +++ /dev/null @@ -1,158 +0,0 @@ -import torch -import torch.utils.data as torch_data -from typing import List, Dict, Union -import datasets - -class TokenizerWrapper(): - ''' - Wrapper for tokenizer - ''' - default_encode_config = dict( - add_special_tokens = True, - padding=True, truncation=True, - return_tensors='pt' - ) - def __init__(self, tokenizer, encode_config=default_encode_config): - self.t = tokenizer - self.encode_config = encode_config - - @property - def mask_token_id(self): - return self.t.mask_token_id - - def encode(self, sentences: Union[str, List[str]]) -> Dict[str, torch.Tensor]: - ''' - Return tokenized sentences - - Returns - encodings: dict with 'input_ids' and 'attention_mask' - ''' - return self.t(sentences, **self.encode_config) - - def decode(self, encodings: torch.LongTensor, skip_special_tokens=True) -> List[str]: - ''' - Return decoded sentences from token id sequences - ''' - def decode_id_seq(s): - decoded = self.t.decode(s, skip_special_tokens=skip_special_tokens) - try: - return bytes(decoded, 'utf8').decode('latin1', 'ignore') - except UnicodeEncodeError as e: - import pdb; pdb.set_trace() - _ = 1 - return [decode_id_seq(e) for e in encodings] - -class SSTLoader(): - ''' - Data loading for sst data. Does encoding etc - - Params: - lim: Use lim of -1 to use all samples. Otherwise uses up to lim samples. - batch_size: Batch size for loaders - tokenizer: Tokenizer - - Returns: - batches from TensorDataset with elements: input_ids, attention_mask, labels - - Dataset reference: - DatasetDict({ - train: Dataset({ - features: ['sentence', 'label', 'idx'], - num_rows: 67349 - }) - validation: Dataset({ - features: ['sentence', 'label', 'idx'], - num_rows: 872 - }) - test: Dataset({ - features: ['sentence', 'label', 'idx'], - num_rows: 1821 - }) - }) - Features: - { - 'sentence': Value(dtype='string', id=None), - 'label': ClassLabel(num_classes=2, names=['negative', 'positive'], - 'idx': Value(dtype='int32', id=None) - } - ''' - def __init__(self, tokenizer: TokenizerWrapper = None, batch_size: int = 8, lim: int = -1): - self.lim = lim - self.tokenizer = tokenizer - self.batch_size = batch_size - self.__load_sst_binary() - - - - def __load_sst_binary(self): - ''' - Set sst train, val and test data - ''' - # Training data from glue (not tokenized) containing keys ('sentence', 'idx', 'label') - raw_dss = datasets.load_dataset("sst") - dss = [raw_dss['train'], raw_dss['validation'], raw_dss['test']] - for i, ds in enumerate(dss): - dss[i] = self.preprocess_dataset(ds) - self.train_dataset, self.val_dataset, self.test_dataset = dss - - def __create_torch_dataloader(self, sents, labels, shuffle) -> torch_data.DataLoader: - encodings = self.tokenizer.encode(sents) - torch_ds = torch_data.TensorDataset( - encodings['input_ids'], - encodings['attention_mask'], - labels - ) - return torch_data.DataLoader(torch_ds, batch_size=self.batch_size, shuffle=shuffle) - - def __sst_to_loader(self, d, s): - return self.__create_torch_dataloader(d['sentence'], torch.as_tensor(d['label'], dtype=torch.long), s) - - @staticmethod - def __sentiment_to_binary(example): - example['label'] = round(example['label']) - return example - - def __preprocess_example_sents(self, example): - example['sentence'] = self.preprocess_sentence(example['sentence']) - example['sentence'] = example['sentence'] + " " + str(int(example['label'])) - return example - - @staticmethod - def preprocess_sentence(s: str): - # remove_punc = "()-[]{};:\",<>/@#$%^&*_~`" - # s = s.lower().strip() - # s = ''.join([c for c in s if c not in remove_punc]) - return s - - def preprocess_dataset(self, ds: datasets.Dataset) -> datasets.Dataset: - if self.lim > 0: - ds = ds.select(range(self.lim)) - ds = ds.map(self.__sentiment_to_binary) - ds = ds.map(self.__preprocess_example_sents) - return ds - - def get_train_loader(self, shuffle=True): - ''' - Encodes dataset and return train dataloader (data batches) - ''' - return self.__sst_to_loader(self.train_dataset, shuffle) - - def get_val_loader(self, shuffle=True): - ''' - Encodes dataset and return val dataloader (data batches) - ''' - return self.__sst_to_loader(self.val_dataset, shuffle) - - def get_test_loader(self, shuffle=True): - ''' - Encodes dataset and return test dataloader (data batches) - ''' - return self.__sst_to_loader(self.test_dataset, shuffle) - -if __name__ == '__main__': - from bart_rl import load_bart_tokenizer - sst2 = SSTLoader(TokenizerWrapper(load_bart_tokenizer()), lim=100) - # train_loader = sst2.get_train_loader() - - - diff --git a/deprecated/train.py b/deprecated/train.py deleted file mode 100644 index a9643d8..0000000 --- a/deprecated/train.py +++ /dev/null @@ -1,198 +0,0 @@ -import torch -from torch import optim -import transformers -from transformers import (BartTokenizerFast, PreTrainedTokenizerFast, BartModel, BartForConditionalGeneration, - LogitsProcessorList, MinLengthLogitsProcessor, StoppingCriteriaList, MaxLengthCriteria, - AutoModelForSeq2SeqLM) -from transformers.generation_utils import GenerationMixin -import datasets -from fuzzywuzzy import fuzz -gm = GenerationMixin - -import pdb -if torch.cuda.is_available(): - dev = "cuda:0" -else: - dev = "cpu" - -print(f'Using device {dev}') -# MODEL_KEY = 'sshleifer/distilbart-cnn-12-3' -MODEL_KEY = 'facebook/bart-base' - -def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): - """ - Shift input ids one token to the right. - """ - shifted_input_ids = input_ids.new_zeros(input_ids.shape) - shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() - shifted_input_ids[:, 0] = decoder_start_token_id - - assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) - - return shifted_input_ids - -def load_sst(start=0, end=10): - return datasets.load_dataset('glue', 'sst2', split=f'train[{start}:{end}]') - -def load_model(): - layers = 2 # default is 12 - config = dict(encoder_layers=layers, decoder_layers=layers) - - model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_KEY, **config) - model = model.to(dev) - return model - -def tokenize_sentences(sentences): - tokenizer = BartTokenizerFast.from_pretrained(MODEL_KEY) - ftok = lambda z: tokenizer(z, truncation=True, padding='longest', return_tensors='pt') - tokenized = ftok([s for s in sentences]) - tokenized.input_ids = tokenized.input_ids.to(dev) - tokenized.attention_mask = tokenized.attention_mask.to(dev) - - return tokenized, tokenizer - -def reward_fuzz_match(s_in, s_out): - ''' - Reward based on fuzzy match ratio. Encourage similarity (toy example) - ''' - r = fuzz.ratio(s_in, s_out) / 100 - return r - -def reward_matching_tokens(s_in, s_out): - ''' - Reward based on number of matching words - ''' - t1 = set(s_in.split()) - t2= set(s_out.split()) - return len(t1 & t2) / len(t1 | t2) - -def get_inputs(model, input_ids): - decoder_start_token_id = model.config.decoder_start_token_id - bos_token_id = model.config.bos_token_id - model_kwargs = dict() - # prepare attention mask and encoder output - model_kwargs["attention_mask"] = gm._prepare_attention_mask_for_generation( - model, input_ids, pad_token_id, eos_token_id) - encoder_input_ids = input_ids if model.config.is_encoder_decoder else None - if model.config.is_encoder_decoder: - model_kwargs = gm._prepare_encoder_decoder_kwargs_for_generation(model, input_ids, model_kwargs) - - input_ids = gm.i( - model, input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id) - - model_kwargs["use_cache"] = None - - logits_processor = gm._get_logits_processor( - model, - repetition_penalty=None, - bad_words_ids=None, - min_length=10, - max_length=11, - eos_token_id=None, - prefix_allowed_tokens_fn=None, - num_beam_groups=None, - diversity_penalty=None, - no_repeat_ngram_size=None, - encoder_no_repeat_ngram_size=None, - encoder_input_ids=encoder_input_ids, - forced_bos_token_id=None, - forced_eos_token_id=None, - num_beams=None, - remove_invalid_values=True) - return input_ids, logits_processor, model_kwargs - -def compute_sequence_score(sequence_ids, sequence_scores): - ''' - sequence_ids: sequence token ids of shape (max_length, ) - sequence_scores: sequence_scores of shape (max_length - 1, vocab_size) containing pre-softmax scores - ''' - sequence_scores = torch.log_softmax(sequence_scores, 1) - policy_scores = [] - for i, id in enumerate(sequence_ids[1:]): - # get score for chosen action i.e which token was generated - score = sequence_scores[i][id] - policy_scores.append(score) - # We should have a score for each token in the sequence - return torch.tensor(policy_scores, requires_grad=True, device=dev).sum() - -def decode_sentences(sequences, tokenizer): - gen_sentences = [tokenizer.decode(s, skip_special_tokens=True).encode('utf-8') for s in outputs['sequences']] - return gen_sentences - - -if __name__ == '__main__': - LR = 1e-3 - USE_AMS = False - EPOCHS = 100 - interval = 10 - TEST = False - sst_dataset = load_sst(end=1) - if TEST: - test_sents = ['test', 'testing', 'test z', 'test y', 'test w'] - encodings, tokenizer = tokenize_sentences(['test', 'testing', 'test z', 'test y', 'test w']) - else: - encodings, tokenizer = tokenize_sentences(sst_dataset['sentence']) - model = load_model() - pad_token_id = model.config.pad_token_id - eos_token_id = model.config.eos_token_id - if TEST: - batches = [(list(range(len(test_sents))), test_sents, encodings)] - else: - batches = [(sst_dataset['idx'], sst_dataset['sentence'], encodings)] - optimizer = optim.Adam(model.parameters(), lr=LR, amsgrad=USE_AMS) - for epoch in range(EPOCHS): - for b in batches: - optimizer.zero_grad() - indices, sentences, e = b - # inputs = get_inputs(model, e.input_ids) - # input_ids, logits_processor, model_kwargs = inputs - ### SAMPLE FOR RL ### - # outputs = gm.sample( - # model, - # input_ids, - # logits_processor=logits_processor, - # pad_token_id=pad_token_id, - # eos_token_id=eos_token_id, - # output_scores=True, - # return_dict_in_generate=True, - # **model_kwargs) - - - # pdb.set_trace() - encoder_output = model.model.encoder(input_ids=e.input_ids, attention_mask=e.attention_mask) - decoder_input_ids = model._prepare_decoder_input_ids_for_generation(e.input_ids) - outputs = model.sample(decoder_input_ids, encoder_outputs=encoder_output, stopping_criteria=MaxLengthCriteria(20), output_scores=True, return_dict_in_generate=True) - # Decode sentences and compute losses - gen_sentences = decode_sentences(outputs['sequences'], tokenizer) - # Get generated sequences of ids and reshape for selecting log probs corresponding to actions - logits = torch.stack(outputs['scores'], dim=0) - log_probs = torch.log_softmax(logits.squeeze(), dim=1) - seq = outputs['sequences'].flatten()[1:].unsqueeze(1) - selected_probs=torch.gather(log_probs, 1, seq) - # Compute reward - rewards = [] - for s_in, s_out in zip(sentences, gen_sentences): - rewards.append(reward_matching_tokens(s_in, s_out.decode('utf-8'))) - rewards = torch.tensor(rewards, requires_grad=False, device=dev) - rewards = rewards - # Compute loss - loss = (rewards * -selected_probs).mean() - ### SUPERVISED COPY #### - # if epoch % interval == 0: - # gen_sentences = model.generate(input_ids=e.input_ids, attention_mask=e.attention_mask) - # gen_sentences = [tokenizer.decode(s, skip_special_tokens=True).encode('utf-8') for s in gen_sentences] - # loss = model(input_ids=e.input_ids, labels=e.input_ids).loss - - loss.backward() - optimizer.step() - - if epoch % interval == 0: - print(f"--------------------------EPOCH {epoch}--------------------------") - print("REWARDS:", rewards) - print("SELECTED_PROBS:",selected_probs) - print("LOSS:", loss) - print("GENERATED:", gen_sentences) - print("TARGETS:", sentences) - # print(log_probs, list(model.named_parameters())[:3]) diff --git a/deprecated/train_1014.py b/deprecated/train_1014.py deleted file mode 100644 index ad27ac5..0000000 --- a/deprecated/train_1014.py +++ /dev/null @@ -1,208 +0,0 @@ -import torch -from torch import optim -import transformers -from transformers import (BartTokenizerFast, PreTrainedTokenizerFast, BartModel, BartForConditionalGeneration, - LogitsProcessorList, MinLengthLogitsProcessor, StoppingCriteriaList, MaxLengthCriteria, - AutoModelForSeq2SeqLM) -from transformers.generation_utils import GenerationMixin -import datasets -from fuzzywuzzy import fuzz -gm = GenerationMixin - -import pdb -if torch.cuda.is_available(): - dev = "cuda:0" -else: - dev = "cpu" -dev = "cpu" - -print(f'Using device {dev}') -# MODEL_KEY = 'sshleifer/distilbart-cnn-12-3' -MODEL_KEY = 'facebook/bart-base' - -def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): - """ - Shift input ids one token to the right. - """ - shifted_input_ids = input_ids.new_zeros(input_ids.shape) - shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() - shifted_input_ids[:, 0] = decoder_start_token_id - - assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." - # replace possible -100 values in labels by `pad_token_id` - shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) - - return shifted_input_ids - -def load_sst(start=0, end=10): - return datasets.load_dataset('glue', 'sst2', split=f'train[{start}:{end}]') - -def load_model(): - layers = 2 # default is 12 - config = dict(encoder_layers=layers, decoder_layers=layers) - - model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_KEY, **config) - model = model.to(dev) - return model - -def tokenize_sentences(sentences): - tokenizer = BartTokenizerFast.from_pretrained(MODEL_KEY) - ftok = lambda z: tokenizer(z, truncation=True, padding='longest', return_tensors='pt') - tokenized = ftok([s for s in sentences]) - tokenized.input_ids = tokenized.input_ids.to(dev) - tokenized.attention_mask = tokenized.attention_mask.to(dev) - - return tokenized, tokenizer - -def reward_match(s_in, s_out): - ''' - Reward based on fuzzy match ratio. Encourage similarity (toy example) - ''' - r = fuzz.ratio(s_in, s_out) / 100 - return r - -def get_inputs(model, input_ids): - decoder_start_token_id = model.config.decoder_start_token_id - bos_token_id = model.config.bos_token_id - model_kwargs = dict() - # prepare attention mask and encoder output - model_kwargs["attention_mask"] = gm._prepare_attention_mask_for_generation( - model, input_ids, pad_token_id, eos_token_id) - encoder_input_ids = input_ids if model.config.is_encoder_decoder else None - if model.config.is_encoder_decoder: - model_kwargs = gm._prepare_encoder_decoder_kwargs_for_generation(model, input_ids, model_kwargs) - - input_ids = gm._prepare_decoder_input_ids_for_generation( - model, input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id) - - model_kwargs["use_cache"] = None - - logits_processor = gm._get_logits_processor( - model, - repetition_penalty=None, - bad_words_ids=None, - min_length=10, - max_length=11, - eos_token_id=None, - prefix_allowed_tokens_fn=None, - num_beam_groups=None, - diversity_penalty=None, - no_repeat_ngram_size=None, - encoder_no_repeat_ngram_size=None, - encoder_input_ids=encoder_input_ids, - forced_bos_token_id=None, - forced_eos_token_id=None, - num_beams=None, - remove_invalid_values=True) - return input_ids, logits_processor, model_kwargs - -def compute_sequence_score(sequence_ids, sequence_scores): - ''' - sequence_ids: sequence token ids of shape (max_length, ) - sequence_scores: sequence_scores of shape (max_length - 1, vocab_size) containing pre-softmax scores - ''' - sequence_scores = torch.log_softmax(sequence_scores, 1) - policy_scores = [] - for i, id in enumerate(sequence_ids[1:]): - # get score for chosen action i.e which token was generated - score = sequence_scores[i][id] - policy_scores.append(score) - # We should have a score for each token in the sequence - return torch.tensor(policy_scores, requires_grad=True, device=dev).sum() - - - - -if __name__ == '__main__': - LR = 1e-3 - USE_AMS = False - EPOCHS = 100 - interval = 10 - TEST = False - sst_dataset = load_sst(end=1) - if TEST: - test_sents = ['test', 'testing', 'test z', 'test y', 'test w'] - encodings, tokenizer = tokenize_sentences(['test', 'testing', 'test z', 'test y', 'test w']) - else: - encodings, tokenizer = tokenize_sentences(sst_dataset['sentence']) - model = load_model() - pad_token_id = model.config.pad_token_id - eos_token_id = model.config.eos_token_id - if TEST: - batches = [(list(range(len(test_sents))), test_sents, encodings)] - else: - batches = [(sst_dataset['idx'], sst_dataset['sentence'], encodings)] - optimizer = optim.Adam(model.parameters(), lr=LR, amsgrad=USE_AMS) - for epoch in range(EPOCHS): - print(epoch) - for b in batches: - optimizer.zero_grad() - indices, sentences, e = b - # inputs = get_inputs(model, e.input_ids) - # input_ids, logits_processor, model_kwargs = inputs - ### SAMPLE FOR RL ### - # outputs = gm.sample( - # model, - # input_ids, - # logits_processor=logits_processor, - # pad_token_id=pad_token_id, - # eos_token_id=eos_token_id, - # output_scores=True, - # return_dict_in_generate=True, - # **model_kwargs) - - - # pdb.set_trace() - encoder_output = model.model.encoder(input_ids=e.input_ids, attention_mask=e.attention_mask) - decoder_input_ids = model.prepare_decoder_input_ids_from_labels(e.input_ids) - outputs = model.sample(decoder_input_ids, encoder_outputs=encoder_output, stopping_criteria=MaxLengthCriteria(30), output_scores=True, return_dict_in_generate=True) - # Decode sentences and compute losses - gen_sentences = [tokenizer.decode(s, skip_special_tokens=True).encode('utf-8') for s in outputs['sequences']] - # get generated sequences of ids and reshape for selecting log probs corresponding to actions - logits = torch.stack(outputs['scores'], dim=0) - log_probs = torch.log_softmax(logits.squeeze(), dim=1) - - # seq = outputs['sequences'][:,1:] - # seq = seq.reshape((seq.shape[1], -1)) - # selected_probs=torch.gather(log_probs, 1, seq) - rewards = [] - for s_in, s_out in zip(sentences, gen_sentences): - rewards.append(reward_match(s_in, s_out.decode('utf-8'))) - rewards = torch.tensor(rewards, requires_grad=False, device=dev) - rewards = rewards - # loss = (rewards * -selected_probs).mean() - # labels = e.input_ids[:,1:] - # # PADDING - padding = torch.tensor([tokenizer.pad_token_id]*(len(logits)-e.input_ids.shape[1])).unsqueeze(0).to(dev) - labels = torch.cat((e.input_ids, padding), 1) - # # PADDING - padding2 = torch.tensor([tokenizer.pad_token_id]*(labels.shape[1]-len(logits))).unsqueeze(0).to(dev) - # pdb.set_trace() - if labels is not None: - loss_fct = torch.nn.CrossEntropyLoss() - z1, z2 = logits.squeeze(1), labels.flatten().unsqueeze(1) - yhat = torch.gather(z1, 1, z2) - try: - masked_lm_loss = loss_fct(yhat, labels) - except: - pdb.set_trace() - loss = masked_lm_loss - - - ### SUPERVISED COPY #### - # if epoch % interval == 0: - # gen_sentences = model.generate(input_ids=e.input_ids, attention_mask=e.attention_mask) - # gen_sentences = [tokenizer.decode(s, skip_special_tokens=True).encode('utf-8') for s in gen_sentences] - # loss = model(input_ids=e.input_ids, labels=e.input_ids).loss - - loss.backward() - optimizer.step() - - if epoch % interval == 0: - print(f"--------------------------EPOCH {epoch}--------------------------") - print("REWARDS:", rewards) - # print("SELECTED_PROBS:",selected_probs) - print("LOSS:", loss) - print("GENERATED:", gen_sentences) - print("TARGETS:", sentences) - # print(log_probs, list(model.named_parameters())[:3]) diff --git a/notebooks/Diversity exploration.ipynb b/notebooks/Diversity exploration.ipynb deleted file mode 100644 index aef1ff9..0000000 --- a/notebooks/Diversity exploration.ipynb +++ /dev/null @@ -1,492 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 13, - "id": "3979e885", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "import pandas as pd\n", - "from pathlib import Path\n", - "import sys\n", - "sys.path.append('..')\n", - "sys.path.append('../..')\n", - "###\n", - "import torchtext\n", - "import torch\n", - "from torchtext.data.utils import get_tokenizer\n", - "from torchtext.vocab import Vocab\n", - "###\n", - "from bs4 import BeautifulSoup\n", - "import re\n", - "###\n", - "import spacy\n", - "# nlp = spacy.load('en_core_web_sm', disable=[\"ner\"])\n", - "###\n", - "import random" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "59116071", - "metadata": {}, - "outputs": [], - "source": [ - "SEED = 1\n", - "SAMPLE_FRAC = 0.1\n", - "random.seed(SEED)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "f2e6f958", - "metadata": {}, - "outputs": [], - "source": [ - "### SST-2 ###\n", - "def parse_line(line):\n", - " index, sent = line.split('\\t')\n", - " if 'sentence_index' in index:\n", - " return (-1,'')\n", - " sent = re.sub('\\n', '', sent)\n", - " index = int(index) - 1\n", - " return (index, sent)\n", - "\n", - "def get_original_sst2():\n", - " # Load SST-2\n", - " sst_dir = data_dir / 'SST2-Data/SST2-Data/stanfordSentimentTreebank/stanfordSentimentTreebank'\n", - " fp = sst_dir / 'datasetSentences.txt'\n", - " sents = {}\n", - " with fp.open('r') as file:\n", - " for i, line in enumerate(file):\n", - " index, sent = parse_line(line)\n", - " if not (index < 0):\n", - " sents[index] = sent\n", - " return sents\n", - " \n", - "### IMDB Processing ###\n", - "#Removing the html strips\n", - "def strip_html(text):\n", - " soup = BeautifulSoup(text, \"html.parser\")\n", - " return soup.get_text()\n", - "\n", - "#Removing the square brackets\n", - "def remove_between_square_brackets(text):\n", - " return re.sub('\\[[^]]*\\]', '', text)\n", - "\n", - "#Removing the noisy text\n", - "def denoise_text(text):\n", - " text = strip_html(text)\n", - " text = remove_between_square_brackets(text)\n", - " return text" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "f9ef99f2", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[PosixPath('../../data/.DS_Store'),\n", - " PosixPath('../../data/SST2-Data'),\n", - " PosixPath('../../data/IMDB Dataset.csv')]" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data_dir = Path('../../data')\n", - "list(data_dir.iterdir())" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "3cab6f7c", - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "# # Load IMDB data\n", - "# df = pd.read_csv(data_dir/'IMDB Dataset.csv')\n", - "# print(f\"Loaded {len(df)} samples, randomly sampling {int(SAMPLE_FRAC * len(df))} rows\")\n", - "# # Sample percentage of data\n", - "# df = df.sample(frac=SAMPLE_FRAC, random_state=SEED)\n", - "# # Convert sentiment columns to numerical values\n", - "# df.sentiment = df.sentiment.apply(lambda x: 1 if x=='positive' else 0)\n", - "# df['review']=df['review'].apply(denoise_text)\n", - "# df.head(3)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "f0f69f29", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(11855,\n", - " \"The gorgeously elaborate continuation of `` The Lord of the Rings '' trilogy is so huge that a column of words can not adequately describe co-writer\\\\/director Peter Jackson 's expanded vision of J.R.R. Tolkien 's Middle-earth .\")" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# Load SST-2 and subsample\n", - "sents = get_original_sst2()\n", - "# num_samples = int(SAMPLE_FRAC * len(sents)) \n", - "# print(f\"Got {len(sents)} samples, randomly sampling {num_samples} samples\")\n", - "# sents = random.sample(sents, num_samples)\n", - "len(sents), sents[1]" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "4cf548ab", - "metadata": {}, - "outputs": [], - "source": [ - "# # Tokenize reviews\n", - "# tokenizer = get_tokenizer('spacy', language='en_core_web_sm')\n", - "# tokenized_texts = [tokenizer(seq) for seq in df.review]" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "6eb6bb87", - "metadata": {}, - "outputs": [], - "source": [ - "# # Process reviews for sentences \n", - "# docs = []\n", - "# for doc in nlp.pipe(df.review):\n", - "# docs.append(doc)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "ae52f25b", - "metadata": {}, - "outputs": [], - "source": [ - "# Tokenize SST-2 Sentences\n", - "# tokenizer = get_tokenizer('spacy', language='en_core_web_sm')\n", - "# sents_tokenized = [tokenizer(sent) for sent in sents]\n", - "\n", - "from nltk.tokenize import word_tokenize\n", - "# sents_tokenized = list(map(word_tokenize, sents))\n", - "# sents_tokenized[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "b2099317", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "[nltk_data] Downloading package stopwords to\n", - "[nltk_data] /Users/Sameer/nltk_data...\n", - "[nltk_data] Package stopwords is already up-to-date!\n", - "[nltk_data] Downloading package punkt to /Users/Sameer/nltk_data...\n", - "[nltk_data] Package punkt is already up-to-date!\n", - "[nltk_data] Downloading package wordnet to /Users/Sameer/nltk_data...\n", - "[nltk_data] Package wordnet is already up-to-date!\n", - "[nltk_data] Downloading package words to /Users/Sameer/nltk_data...\n", - "[nltk_data] Package words is already up-to-date!\n", - "../../eda_nlp/code/eda.py:177: SyntaxWarning: \"is not\" with a literal. Did you mean \"!=\"?\n", - " words = [word for word in words if word is not '']\n" - ] - } - ], - "source": [ - "from TextGenerationEvaluationMetrics import multiset_distances as MSD\n", - "from DataAugmentation.data import augmentation\n", - "from eda_nlp.code.eda import get_only_chars\n", - "# from DataAugmentation.data import back_translation" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "30c7c225", - "metadata": {}, - "outputs": [], - "source": [ - "# from inspect import getmembers, isfunction\n", - "# print(getmembers(augmentation, isfunction))" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "076805ca", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(11855, 11855)" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "'''\n", - "Generating eda samples:\n", - "python eda_nlp/code/augment.py \n", - " --input=data/SST2-Data/SST2-Data/stanfordSentimentTreebank/stanfordSentimentTreebank/datasetSentences.txt \n", - " --output=./sst2_augmented.txt \n", - " --num_aug=5 --alpha_sr=0.3 --alpha_rd=0.1 --alpha_ri=0.1 --alpha_rs=0.0\n", - "'''\n", - "def get_augmented_sst2():\n", - " fp = Path('../../sst2_augmented.txt')\n", - " sents = {}\n", - " with fp.open('r') as file:\n", - " for i, line in enumerate(file):\n", - " index, sent = parse_line(line)\n", - " if not (index < 0):\n", - " if index in sents:\n", - " sents[index].append(sent)\n", - " else:\n", - " sents[index] = [sent]\n", - " return sents\n", - "\n", - "orig_sents = get_original_sst2()\n", - "orig_sents = {j:get_only_chars(t) for j, t in orig_sents.items()}\n", - "aug_sents = get_augmented_sst2()\n", - "len(orig_sents), len(aug_sents)" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "895edd7b", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "('the rock is destined to be the st century s new conan and that he s going to make a splash even greater than arnold schwarzenegger jean claud van damme or steven segal ',\n", - " ['the shake is bound to be the st hundred s newly conan and that he s going to make a splash even great than matthew arnold schwarzenegger blue jean claud caravan damme or steven george segal',\n", - " 'the sway is destined to be the st c s new conan and that he s live on to make believe a splash yet nifty than benedict arnold schwarzenegger jean claud new wave damme or steven george segal',\n", - " 'the destined to be the st century s new conan and that he s going to make a splash greater than arnold schwarzenegger jean claud van damme steven segal',\n", - " 'the rock is destined to be george segal the st century s new conan and that he s going to make a splash even matthew arnold greater than arnold schwarzenegger jean claud van damme or steven atomic number segal',\n", - " 'the rock is destined to be the st century s new conan and that he s going to make a splash even greater than arnold atomic number schwarzenegger jean atomic number claud van damme represent or steven segal',\n", - " 'the rock is destined to be the st century s new conan and that he s going to make a splash even greater than arnold schwarzenegger jean claud van damme or steven segal '])" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "orig_sents[0], aug_sents[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "id": "299b938b", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Jaccard distances preprocess upto 5!\n" - ] - }, - { - "data": { - "text/plain": [ - "{1: 0.0233949945593036, 2: 0.0, 3: 0.0, 4: 0.0, 5: 0.0}" - ] - }, - "execution_count": 32, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "index = 1\n", - "\n", - "# ref1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'that', 'ensures', 'that', 'the', 'military', 'will', 'forever', 'heed', 'Party', 'commands']\n", - "# ref2 = ['It', 'is', 'the', 'guiding', 'principle', 'which', 'guarantees', 'the', 'military', 'forces', 'always', 'being', 'under', 'the', 'command', 'of', 'the', 'Party']\n", - "# ref3 = ['It', 'is', 'the', 'practical', 'guide', 'for', 'the', 'army', 'always', 'to', 'heed', 'the', 'directions', 'of', 'the', 'party']\n", - "# sen1 = ['It', 'is', 'a', 'guide', 'to', 'action', 'which', 'ensures', 'that', 'the', 'military', 'always', 'obeys', 'the', 'commands', 'of', 'the', 'party']\n", - "# sen2 = ['he', 'read', 'the', 'book', 'because', 'he', 'was', 'interested', 'in', 'world', 'history']\n", - "\n", - "references = map(word_tokenize, aug_sents[index])\n", - "sentences = map(word_tokenize, [orig_sents[index]])\n", - "sentences, references = map(list, (sentences, references))\n", - "\n", - "msd = MSD.MultisetDistances(references=references, min_n=1, max_n=5)\n", - "msj_distance = msd.get_jaccard_score(sentences=sentences)\n", - "msj_distance\n" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "1e99f1e1", - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "data": { - "text/plain": [ - "['the gorgeously elaborated continuation of the lord of the anchor ring trilogy is so brobdingnagian that a tower of words can not adequately draw co writer director tool old hickory s flesh out visual sensation of joule r r tolkien s middle ground',\n", - " 'the gorgeously work out continuance of the lord of the reverberate trilogy is so brobdingnagian that a newspaper column of words can not adequately discover co author managing director peter andrew jackson s expanded vision of j radius radius tolkien s middle terra firma',\n", - " 'the gorgeously elaborate continuation of the lord of the rings trilogy is so huge that a column of words can not adequately describe co writer director peter jackson s expanded vision and then of j r r tolkien immense michael joe jackson s middle earth',\n", - " 'the gorgeously of the lord of the rings trilogy is so that a column of words can not adequately co director peter s expanded vision of j r r tolkien s middle earth',\n", - " 'the gorgeously elaborate continuation of the lord the rings trilogy is so huge that a column of words can not adequately describe co director peter jackson s expanded vision of j r r tolkien s middle earth',\n", - " 'the gorgeously elaborate continuation of the lord of the rings trilogy is so huge that a column of words can not adequately describe co writer director peter jackson s expanded vision of j r r tolkien s middle earth']" - ] - }, - "execution_count": 29, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "[' '.join(s) for s in references]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6f4db50e", - "metadata": {}, - "outputs": [], - "source": [ - "' '.join(sentences[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "e0b2f334", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['the',\n", - " 'gorgeously',\n", - " 'elaborate',\n", - " 'continuation',\n", - " 'of',\n", - " 'the',\n", - " 'lord',\n", - " 'of',\n", - " 'the',\n", - " 'rings',\n", - " 'trilogy',\n", - " 'is',\n", - " 'so',\n", - " 'huge',\n", - " 'that',\n", - " 'a',\n", - " 'column',\n", - " 'of',\n", - " 'words',\n", - " 'can',\n", - " 'not',\n", - " 'adequately',\n", - " 'describe',\n", - " 'co',\n", - " 'writer',\n", - " 'director',\n", - " 'peter',\n", - " 'jackson',\n", - " 's',\n", - " 'expanded',\n", - " 'vision',\n", - " 'of',\n", - " 'j',\n", - " 'r',\n", - " 'r',\n", - " 'tolkien',\n", - " 's',\n", - " 'middle',\n", - " 'earth']" - ] - }, - "execution_count": 31, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sentences[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b9dc23dd", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "dda", - "language": "python", - "name": "dda" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.0" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/package.json b/package.json new file mode 100755 index 0000000..b026c76 --- /dev/null +++ b/package.json @@ -0,0 +1,47 @@ +{ + "name": "starter-react-app", + "version": "1.0.0", + "description": "A starter react app based on the latest standards with Typescript.", + "main": "src/index.tsx", + "scripts": { + "start": "react-scripts start", + "build": "react-scripts build", + "test": "react-scripts test", + "eject": "react-scripts eject" + }, + "dependencies": { + "@types/jest": "^26.0.15", + "@types/node": "^12.0.0", + "@types/react": "^17.0.0", + "@types/react-dom": "^17.0.0", + "firebase": "^8.2.1", + "react": "^17.0.2", + "react-dom": "^17.0.2", + "react-scripts": "4.0.3", + "typescript": "^4.1.2" + }, + "devDependencies": { + "@typescript-eslint/eslint-plugin": "^4.5.0", + "@typescript-eslint/parser": "^4.5.0", + "eslint": "^7.11.0", + "eslint-plugin-react": "^7.21.5" + }, + "eslintConfig": { + "extends": [ + "react-app", + "react-app/jest" + ] + }, + "browserslist": { + "production": [ + ">0.2%", + "not dead", + "not op_mini all" + ], + "development": [ + "last 1 chrome version", + "last 1 firefox version", + "last 1 safari version" + ] + } +} \ No newline at end of file diff --git a/public/index.html b/public/index.html new file mode 100755 index 0000000..88a13f6 --- /dev/null +++ b/public/index.html @@ -0,0 +1,12 @@ + + + + + + Starter React App + + +
+ + + \ No newline at end of file diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index f7ee9a9..0000000 --- a/requirements.txt +++ /dev/null @@ -1,37 +0,0 @@ -aiohttp==3.7.4.post0 -async-timeout==3.0.1 -attrs==21.2.0 -certifi==2021.5.30 -chardet==4.0.0 -charset-normalizer==2.0.6 -click==8.0.1 -datasets==1.12.1 -dill==0.3.4 -filelock==3.3.0 -fsspec==2021.10.0 -huggingface-hub==0.0.19 -idna==3.2 -joblib==1.1.0 -multidict==5.2.0 -multiprocess==0.70.12.2 -numpy==1.21.2 -packaging==21.0 -pandas==1.3.3 -pyarrow==5.0.0 -pyparsing==2.4.7 -python-dateutil==2.8.2 -pytz==2021.3 -PyYAML==5.4.1 -regex==2021.9.30 -requests==2.26.0 -sacremoses==0.0.46 -six==1.16.0 -tokenizers==0.10.3 -torch==1.9.1 -tqdm==4.62.3 -transformers==4.11.0 -typing-extensions==3.10.0.2 -urllib3==1.26.7 -xxhash==2.0.2 -yarl==1.7.0 -DELETE THIS diff --git a/reward.py b/reward.py deleted file mode 100644 index b4a4192..0000000 --- a/reward.py +++ /dev/null @@ -1,118 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Dict -import editdistance -import transformers -from nltk import ngrams -from sentence_transformers import SentenceTransformer -import clf_sst2 -import torch -import utils -import numpy as np -# class SentencePairReward(ABC): -# ''' -# SentencePairReward is a base class for rewards that take in a pair of sentences and computes single or multiple reward between them e.g. -# diversity between input and generated sentences. -# ''' -# def __init__(self, weight: float): -# ''' -# Weight should be the weight for this reward -# ''' -# pass - -# @abstractmethod -# def reward(s1: str, s2: str) -> float: -# pass - -# class EditDistanceReward(SentencePairReward): -# ''' -# Edit distance reward -# ''' -# def __init(self): - - -class RewardWrapper(): - ''' - RewardWrapper for reward functions and state. - - Assume for all inputs s1, s2 unless said otherwise: - s1 (str): input sentence - s2 (str): generated sentence - ''' - embedder_key = 'all-distilroberta-v1' - - def __init__(self, clf = None, device=utils.DEV): - self.device = device - self.clf = clf_sst2.DistilBertSST(device=device) - self.embedder = SentenceTransformer(self.embedder_key, device=device) - - - def edit_distance(self, s1: str, s2: str) -> float: - ''' - Return normalized edit distance in [0, 1]. - Note, this is Levenshtein distance. - ''' - - N = len(s1) + len(s2) - ed = editdistance.eval(s1, s2) / N - r = 1 - ed - return r - - def iou_ungrams(self, s1: str, s2: str, n = 1) -> float: - ''' - Get ngram overlap between s1 and s2 as fraction in [0, 1] - IOU(ngrams(s1), ngrams(s2)) - ''' - def get_ngram_set(s, n=n): - toks = s.split() - if len(toks) < n: - return set() - return set(ngrams(toks, n=n)) - - ng1, ng2 = get_ngram_set(s1), get_ngram_set(s2) - score = 0 - if len(ng1) > 0 and len(ng2) > 0: - score = len(ng1 & ng2) / len(ng1 | ng2) - return score - - def embed_similarity(self, s1: str, s2: str): - e1, e2 = self.embedder.encode([s1, s2], convert_to_tensor=True) - return torch.nn.functional.cosine_similarity(e1, e2, dim=0).item() - - def clf_consistency(self, s1: str, s2: str, y: int) -> float: - ''' - Consistency reward from classifier. - Inputs: - y: label for s1 - ''' - yhat = self.clf.predict_on_text(s2)[0] - # return 1. if yhat == y else -1 - return 1. if yhat == y else 0 - - def compute_rewards(self, s1: str, s2: str, y: int, beta=1.0) -> Dict[str, float]: - ''' - Inputs: - y: label for s1 - ''' - iou_ngram_scores = np.asarray([pow(beta, n) * self.iou_ungrams(s1, s2, n=n) for n in (1, 2, 3)]) - return { - "edit_distance": self.edit_distance(s1, s2), - "iou_ungrams": iou_ngram_scores.mean(), - "embed_similarity": self.embed_similarity(s1, s2), - "clf_consistency": self.clf_consistency(s1, s2, y) - } - - -def main(): - reward_f = RewardWrapper() - N = 2 - tests = [ - ( "hi my name is sameer", "hi my name is bob", 0), - ("hi my name is sameer", "hi my name is sam", 0), - ("hi my name is sameer", "i hate bob hes a bad person", 0) - ] - for t in tests: - rewards = reward_f.compute_rewards(*t) - print(rewards) - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/shared_dependencies.md b/shared_dependencies.md new file mode 100755 index 0000000..51f561b --- /dev/null +++ b/shared_dependencies.md @@ -0,0 +1,23 @@ +1. **React**: All the files in the `src` directory will share the React library as a dependency. This includes the use of React components, hooks, and JSX. + +2. **Typescript**: All the `.tsx` files will share Typescript as a dependency. This includes the use of Typescript types, interfaces, and syntax. + +3. **Firebase Authentication**: The `auth.ts` service and the `Login.tsx`, `SignUp.tsx`, and `Logout.tsx` components will share Firebase Authentication as a dependency. This includes the use of Firebase's authentication methods and user object. + +4. **User Type**: The `user.ts` file will export a User type that will be shared by the `auth.ts` service and the `Login.tsx`, `SignUp.tsx`, and `Logout.tsx` components. + +5. **Auth Service**: The `auth.ts` file will export authentication functions that will be shared by the `Login.tsx`, `SignUp.tsx`, and `Logout.tsx` components. + +6. **Firebase Utility**: The `firebase.ts` utility file will be shared by the `auth.ts` service and potentially other files that require Firebase functionality. + +7. **CSS Styles**: The `global.css`, `login.css`, `signup.css`, and `logout.css` files will be shared by the respective components that require these styles. + +8. **DOM Element IDs**: The `Login.tsx`, `SignUp.tsx`, and `Logout.tsx` components will likely share DOM element IDs for form inputs and buttons that will be used by the authentication functions. + +9. **ProtectedRoute Component**: The `ProtectedRoute.tsx` component will be shared by any routes that require authentication. + +10. **Package.json**: All files will share the dependencies listed in the `package.json` file. + +11. **tsconfig.json**: All Typescript files will share the configuration specified in the `tsconfig.json` file. + +12. **index.html**: All components will be rendered into the root DOM element specified in the `index.html` file. \ No newline at end of file diff --git a/src/App.tsx b/src/App.tsx new file mode 100755 index 0000000..5b95149 --- /dev/null +++ b/src/App.tsx @@ -0,0 +1,22 @@ +import React from 'react'; +import { BrowserRouter as Router, Route, Switch } from 'react-router-dom'; +import Login from './components/Login'; +import SignUp from './components/SignUp'; +import Logout from './components/Logout'; +import ProtectedRoute from './components/ProtectedRoute'; +import './styles/global.css'; + +const App: React.FC = () => { + return ( + + + + + + + + + ); +} + +export default App; \ No newline at end of file diff --git a/src/components/Login.tsx b/src/components/Login.tsx new file mode 100755 index 0000000..e48513c --- /dev/null +++ b/src/components/Login.tsx @@ -0,0 +1,44 @@ +import React, { useState } from 'react'; +import { useHistory } from 'react-router-dom'; +import { login } from '../services/auth'; +import '../styles/login.css'; + +const Login: React.FC = () => { + const [email, setEmail] = useState(''); + const [password, setPassword] = useState(''); + const history = useHistory(); + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + try { + await login(email, password); + history.push('/'); + } catch (error) { + alert(error.message); + } + }; + + return ( +
+
+ setEmail(e.target.value)} + required + /> + setPassword(e.target.value)} + required + /> + +
+
+ ); +}; + +export default Login; \ No newline at end of file diff --git a/src/components/Logout.tsx b/src/components/Logout.tsx new file mode 100755 index 0000000..8e90579 --- /dev/null +++ b/src/components/Logout.tsx @@ -0,0 +1,27 @@ +import React from 'react'; +import { useHistory } from 'react-router-dom'; +import { logout } from '../services/auth'; +import '../styles/logout.css'; + +const Logout: React.FC = () => { + const history = useHistory(); + + const handleLogout = async () => { + try { + await logout(); + history.push('/login'); + } catch (error) { + console.error(error); + } + }; + + return ( +
+ +
+ ); +}; + +export default Logout; \ No newline at end of file diff --git a/src/components/ProtectedRoute.tsx b/src/components/ProtectedRoute.tsx new file mode 100755 index 0000000..5e14390 --- /dev/null +++ b/src/components/ProtectedRoute.tsx @@ -0,0 +1,24 @@ +import React from 'react'; +import { Route, Redirect } from 'react-router-dom'; +import { useAuth } from '../services/auth'; + +interface ProtectedRouteProps { + component: React.FC; + path: string; + exact?: boolean; +} + +const ProtectedRoute: React.FC = ({ component: Component, ...rest }) => { + const { currentUser } = useAuth(); + + return ( + + currentUser ? : + } + /> + ); +}; + +export default ProtectedRoute; \ No newline at end of file diff --git a/src/components/SignUp.tsx b/src/components/SignUp.tsx new file mode 100755 index 0000000..7166b1d --- /dev/null +++ b/src/components/SignUp.tsx @@ -0,0 +1,48 @@ +import React, { useState } from 'react'; +import { useHistory } from 'react-router-dom'; +import { signUp } from '../services/auth'; +import '../styles/signup.css'; + +const SignUp: React.FC = () => { + const [email, setEmail] = useState(''); + const [password, setPassword] = useState(''); + const [error, setError] = useState(null); + const history = useHistory(); + + const handleSignUp = async (event: React.FormEvent) => { + event.preventDefault(); + try { + await signUp(email, password); + history.push('/'); + } catch (error) { + setError(error.message); + } + }; + + return ( +
+
+ setEmail(e.target.value)} + required + /> + setPassword(e.target.value)} + required + /> + +
+ {error &&

{error}

} +
+ ); +}; + +export default SignUp; \ No newline at end of file diff --git a/src/index.tsx b/src/index.tsx new file mode 100755 index 0000000..68c4732 --- /dev/null +++ b/src/index.tsx @@ -0,0 +1,14 @@ +import React from 'react'; +import ReactDOM from 'react-dom'; +import './styles/global.css'; +import App from './App'; +import * as serviceWorker from './serviceWorker'; + +ReactDOM.render( + + + , + document.getElementById('root') +); + +serviceWorker.unregister(); \ No newline at end of file diff --git a/src/services/auth.ts b/src/services/auth.ts new file mode 100755 index 0000000..0e4d182 --- /dev/null +++ b/src/services/auth.ts @@ -0,0 +1,41 @@ +import firebase from '../utils/firebase'; +import { User } from '../types/user'; + +export const signUp = async (email: string, password: string): Promise => { + try { + const response = await firebase.auth().createUserWithEmailAndPassword(email, password); + return { + uid: response.user?.uid, + email: response.user?.email, + }; + } catch (error) { + console.error(error); + return null; + } +}; + +export const login = async (email: string, password: string): Promise => { + try { + const response = await firebase.auth().signInWithEmailAndPassword(email, password); + return { + uid: response.user?.uid, + email: response.user?.email, + }; + } catch (error) { + console.error(error); + return null; + } +}; + +export const logout = async (): Promise => { + try { + await firebase.auth().signOut(); + } catch (error) { + console.error(error); + } +}; + +export const getCurrentUser = (): User | null => { + const user = firebase.auth().currentUser; + return user ? { uid: user.uid, email: user.email } : null; +}; \ No newline at end of file diff --git a/src/styles/global.css b/src/styles/global.css new file mode 100755 index 0000000..01da989 --- /dev/null +++ b/src/styles/global.css @@ -0,0 +1,26 @@ +/* src/styles/global.css */ + +body { + margin: 0; + padding: 0; + font-family: Arial, sans-serif; + background-color: #f4f4f4; +} + +.container { + max-width: 1200px; + margin: 0 auto; + padding: 0 15px; +} + +button { + cursor: pointer; +} + +input, button { + margin-top: 10px; +} + +.error { + color: red; +} \ No newline at end of file diff --git a/src/styles/login.css b/src/styles/login.css new file mode 100755 index 0000000..49803a6 --- /dev/null +++ b/src/styles/login.css @@ -0,0 +1,42 @@ +.login-container { + display: flex; + justify-content: center; + align-items: center; + height: 100vh; + background-color: #f5f5f5; +} + +.login-form { + width: 300px; + padding: 16px; + background-color: #fff; + border-radius: 8px; + box-shadow: 0px 0px 10px rgba(0, 0, 0, 0.1); +} + +.login-form input { + width: 100%; + padding: 10px; + margin-bottom: 10px; + border: 1px solid #ddd; + border-radius: 4px; +} + +.login-form button { + width: 100%; + padding: 10px; + border: none; + border-radius: 4px; + background-color: #007bff; + color: #fff; + cursor: pointer; +} + +.login-form button:hover { + background-color: #0056b3; +} + +.login-form .error-message { + color: red; + margin-bottom: 10px; +} \ No newline at end of file diff --git a/src/styles/logout.css b/src/styles/logout.css new file mode 100755 index 0000000..11126f6 --- /dev/null +++ b/src/styles/logout.css @@ -0,0 +1,21 @@ +.logout-container { + display: flex; + justify-content: center; + align-items: center; + height: 100vh; + background-color: #f5f5f5; +} + +.logout-button { + padding: 10px 20px; + font-size: 16px; + border: none; + border-radius: 5px; + background-color: #007bff; + color: white; + cursor: pointer; +} + +.logout-button:hover { + background-color: #0056b3; +} \ No newline at end of file diff --git a/src/styles/signup.css b/src/styles/signup.css new file mode 100755 index 0000000..beb4d36 --- /dev/null +++ b/src/styles/signup.css @@ -0,0 +1,42 @@ +.signup-container { + display: flex; + justify-content: center; + align-items: center; + height: 100vh; + background-color: #f5f5f5; +} + +.signup-form { + width: 300px; + padding: 20px; + background-color: #ffffff; + border-radius: 5px; + box-shadow: 0px 0px 10px 0px rgba(0,0,0,0.1); +} + +.signup-form input { + width: 100%; + padding: 10px; + margin-bottom: 10px; + border: 1px solid #ddd; + border-radius: 5px; +} + +.signup-form button { + width: 100%; + padding: 10px; + background-color: #007bff; + border: none; + border-radius: 5px; + color: #ffffff; + cursor: pointer; +} + +.signup-form button:hover { + background-color: #0056b3; +} + +.signup-form .error-message { + color: #ff0000; + margin-bottom: 10px; +} \ No newline at end of file diff --git a/src/types/user.ts b/src/types/user.ts new file mode 100755 index 0000000..feffaa7 --- /dev/null +++ b/src/types/user.ts @@ -0,0 +1,6 @@ +export interface User { + uid: string; + email: string | null; + displayName: string | null; + photoURL: string | null; +} \ No newline at end of file diff --git a/src/utils/firebase.ts b/src/utils/firebase.ts new file mode 100755 index 0000000..492a1eb --- /dev/null +++ b/src/utils/firebase.ts @@ -0,0 +1,15 @@ +import firebase from "firebase/app"; +import "firebase/auth"; + +const firebaseConfig = { + apiKey: "YOUR_API_KEY", + authDomain: "YOUR_AUTH_DOMAIN", + projectId: "YOUR_PROJECT_ID", + storageBucket: "YOUR_STORAGE_BUCKET", + messagingSenderId: "YOUR_MESSAGING_SENDER_ID", + appId: "YOUR_APP_ID" +}; + +firebase.initializeApp(firebaseConfig); + +export default firebase; \ No newline at end of file diff --git a/train_bart_v2.py b/train_bart_v2.py deleted file mode 100644 index 8226b8a..0000000 --- a/train_bart_v2.py +++ /dev/null @@ -1,151 +0,0 @@ -import torch -from transformers import logging -from transformers import trainer_utils -from tqdm import tqdm -from typing import List, Dict, Union -import os -import numpy as np -from utils import get_adamw, DEV -from data import SSTLoader, TokenizerWrapper -from bart_rl import BartReinforce, load_bart_model, load_bart_tokenizer -from reward import RewardWrapper -import argparse -SEED = 42 - -def parse_args(): - p = argparse.ArgumentParser() - p.add_argument('--gpu', type=int, help="GPU id") - return p.parse_args() - -def mask_span(seq_ids: torch.LongTensor, tokenizer: TokenizerWrapper, span_range=(4, 5)): - start, end = 2, 4 - for i in range(start, end): - seq_ids[i] = tokenizer.mask_token_id # mask one token - -def compute_rewards(reward_f: RewardWrapper, inputs: List[str], - outputs: List[str], labels: torch.LongTensor, - subtract_mean: bool = False, verbose = False) -> torch.FloatTensor: - ''' - Get rewards between input and output sentences - - Returns: - :torch.FloatTensor of shape (batch_size, 1): Batch size length tensor of rewards - ''' - rewards = [] - for s1, s2, label in zip(inputs, outputs, labels.numpy()): - r_dict = reward_f.compute_rewards(s1, s2, label) - editd = r_dict['edit_distance'] # + 0.0 - # iou = r_dict['iou_ungrams'] - 0.5 - es = r_dict['embed_similarity'] - 0.7 - con = r_dict['clf_consistency'] - alpha = 0.1 - r1 = (1-alpha) * (con + es) - # r2 = alpha * editd - r2 = 0 - rewards.append(r1 + r2) - # print(con, es, r1) - # print(r2) - rewards = torch.as_tensor(rewards, device=reward_f.device) - if subtract_mean and rewards.shape[0] > 1: - rewards = rewards - rewards.mean() - return rewards - -def call_config_functions(): - trainer_utils.set_seed(SEED) - torch.autograd.set_detect_anomaly(True) - logging.set_verbosity_error() - - -def main(): - args = parse_args() - call_config_functions() - device = DEV - if args.gpu is not None: - device = f"cuda:{args.gpu}" - - # Training config - batch_size = 1 - epochs = 100 - print_interval = 10 # epochs // 5 - verbose = False - use_tqdm = True - lim = 10 - reward_baseline_sub_mean = True - # Model config - freeze_encoder_params = True - optim_config = dict( - lr=5e-6, - wd=0.01 - ) - episode_config = dict( - epsilon=0.00, - temperature=0.7, - topk=200, - min_length=15, - verbose=verbose - ) - max_out_len = lambda input_len: int(input_len * 1.3) - - # Load data, model, optimizer - bart = load_bart_model() - optimizer = get_adamw(bart, **optim_config) - tokenizer = TokenizerWrapper(load_bart_tokenizer()) - # Load data into batches - sst2 = SSTLoader(tokenizer, batch_size=batch_size, lim=lim) - train_loader = sst2.get_train_loader() - - # Init seq2seq RL model and reward functions - rl_model = BartReinforce(bart, device=device) - if freeze_encoder_params: - rl_model.freeze_encoder_params() - reward_f = RewardWrapper(device=device) - # Train - for epoch in range(epochs): - with tqdm(train_loader, unit="batch") as tepoch: - tepoch.set_description(f"Epoch {epoch}") - # Train epoch - bart.train() - for batch_num, batch in enumerate(tepoch): # batch = Tuple(input_ids, attention_mask, labels) - # Generate episodes from batch of input - max_in_len = batch[0].shape[1] - out = rl_model.generate_episodes(batch, max_length=max_out_len(max_in_len), **episode_config) - # Decode sequences to sentence strings - input_sentences, output_sentences = map(tokenizer.decode, (batch[0], out)) - output_sentences = [sst2.preprocess_sentence(s) for s in output_sentences] - # Compute rewards - r = compute_rewards(reward_f, input_sentences, output_sentences, - labels=batch[2], subtract_mean=reward_baseline_sub_mean, verbose=verbose) - # Compute loss - # rl_model.log_probs = torch.nan_to_num(rl_model.log_probs, neginf=0) - # loss_batch = (r.unsqueeze(-1) * -rl_model.log_probs).sum(1) - # loss = loss_batch.mean() - loss = torch.tensor([0.0], requires_grad=True, device=device) - for i in range(batch_size): - z = -rl_model.log_probs[i] - z = z[~torch.isinf(z)] - loss = loss + (r[i] * z).sum() - loss = loss/batch_size - # import pdb; pdb.set_trace() - # Backward - loss.backward() - # try: - # torch.nn.utils.clip_grad_norm_(bart.parameters(), 1.0, error_if_nonfinite=True) - # except Exception as e: - # print(list(bart.parameters())[-1].grad) - # import pdb; pdb.set_trace() - # _ = 1 - optimizer.step() - # Clear state - rl_model.clear_episode_batch() - optimizer.zero_grad() - # Print info to stdout - if print_interval > 0 and epoch % print_interval == 0: - [ - tqdm.write(f"Input:\n{i}\nOutput:\n{o}\n") for i,o in zip(input_sentences, output_sentences) - ] - tqdm.write(f"Loss: {loss.item()}") - tqdm.write(f"Rewards: {r.squeeze().cpu()}") - tqdm.write(f"{'-'*15}") - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/tsconfig.json b/tsconfig.json new file mode 100755 index 0000000..8b27864 --- /dev/null +++ b/tsconfig.json @@ -0,0 +1,25 @@ +{ + "compilerOptions": { + "target": "es5", + "lib": [ + "dom", + "dom.iterable", + "esnext" + ], + "allowJs": true, + "skipLibCheck": true, + "esModuleInterop": true, + "allowSyntheticDefaultImports": true, + "strict": true, + "forceConsistentCasingInFileNames": true, + "module": "esnext", + "moduleResolution": "node", + "resolveJsonModule": true, + "isolatedModules": true, + "noEmit": true, + "jsx": "react-jsx" + }, + "include": [ + "src" + ] +} \ No newline at end of file diff --git a/utils.py b/utils.py deleted file mode 100644 index c22ff78..0000000 --- a/utils.py +++ /dev/null @@ -1,54 +0,0 @@ -from transformers import ( - BartTokenizerFast, BartForConditionalGeneration, - AdamW -) -import torch -import numpy as np -from typing import Union - -def seed_everything(seed: int): - import random, os - import numpy as np - import torch - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - -# SEED -seed_everything(42) -# DEVICE -if torch.cuda.is_available(): - DEV = "cuda" -else: - DEV = "cpu" -LOG_EPS = 0 # 1e-8 -MODEL_KEY = 'sshleifer/distilbart-cnn-6-6' -# MODEL_KEY = 'facebook/bart-base' - -def get_adamw(model, lr=1e-5, eps=1e-8, wd=0.): - ''' - Get adamw optimizer - ''' - return AdamW(model.parameters(), lr=lr, eps=eps, weight_decay=wd) - -def convert_to_numpy(arr: Union[np.ndarray, torch.Tensor, list]) -> np.ndarray: - ''' - Convert array to numpy - ''' - if isinstance(arr, torch.Tensor): - arr = arr.cpu().numpy() - if isinstance(arr, list): - arr = np.asarray(arr) - return arr - -def flat_accuracy(logits, labels) -> float: - ''' - Helper for accuracy between logits and labels in Tensor format or numpy arrays - ''' - logits, labels = map(convert_to_numpy, (logits, labels)) - pred_flat = np.argmax(logits, axis=1).flatten() - labels_flat = labels.flatten() - return np.sum(pred_flat == labels_flat) / len(labels_flat) - -