Skip to content

feature: integrate tile db loading in data_loader#6

Open
VittorioRossi wants to merge 7 commits into
mainfrom
vitto
Open

feature: integrate tile db loading in data_loader#6
VittorioRossi wants to merge 7 commits into
mainfrom
vitto

Conversation

@VittorioRossi

Copy link
Copy Markdown
Collaborator

No description provided.

@giacomo-ciro giacomo-ciro left a comment

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would modify the dataloader to return:

return {
            "input_ids": input_ids,             # [max_length,]: LongTensor
            "eos_labels": eos_labels,           # [max_length,]: LongTensor
            "response_length": response_length, # [1,]: FloatTensor
            "input_mask": input_mask            # [max_length,]: BoolTensor,
            ++ "input_embeddings: input_embeddings # [max_length, d_model]: FloatTensor
        }

Effectively, the input_ids won't be needed anymore, but I suggest we keep them just in case (the pipeline is already in place, just ignore them).

And also modify the models' forward pass (LLaDaRegressor.forward() and LLaDaClassifier.forward() in models/llada.py) accordingly to take directly the embeddings as input and work on those.

As of now, the model takes only input_ids and calls hidden_state = self.llada.get_last_hidden_state(input_ids), since we have the embeddings stored on disk we can skip this part and directly go with hidden_state = input_embeddings.

@giacomo-ciro

Copy link
Copy Markdown
Owner

@VittorioRossi ho fatto il wrapper al trainer in trainer.py e fatto lo script train.py che ora e' super clean (5 righe).

Rimane da adattare il forward pass dei modelli e testare tutto (io non ho ancora runnato il mio perche mancano delle cose,quindi ci saranno bug, ma una volta terminato tutto il refactor li fixiamo

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants