feature: integrate tile db loading in data_loader #6
I would modify the dataloader to return:
return {
    "input_ids": input_ids,  # [max_length,]: LongTensor
    "eos_labels": eos_labels,  # [max_length,]: LongTensor
    "response_length": response_length,  # [1,]: FloatTensor
    "input_mask": input_mask,  # [max_length,]: BoolTensor
    "input_embeddings": input_embeddings,  # new key: [max_length, d_model]: FloatTensor
}
In practice, input_ids will no longer be needed, but I suggest we keep them just in case (the pipeline is already in place, so we can simply ignore them).
We should also modify the models' forward passes (LLaDaRegressor.forward() and LLaDaClassifier.forward() in models/llada.py) to take the embeddings directly as input and work on those.
As of now, the model takes only input_ids and calls hidden_state = self.llada.get_last_hidden_state(input_ids). Since the embeddings are stored on disk, we can skip this step and go directly with hidden_state = input_embeddings.
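To make the suggestion concrete, here is a minimal sketch of what the adapted forward pass could look like. The class name `EmbeddingRegressor`, the linear head, and the masked mean pooling are assumptions for illustration only; the actual LLaDaRegressor in models/llada.py may pool and project differently.

```python
import torch
from torch import nn


class EmbeddingRegressor(nn.Module):
    """Hypothetical stand-in for LLaDaRegressor that consumes precomputed embeddings."""

    def __init__(self, d_model: int):
        super().__init__()
        # Assumed regression head; the real model's head may differ.
        self.head = nn.Linear(d_model, 1)

    def forward(self, input_embeddings: torch.Tensor, input_mask: torch.Tensor) -> torch.Tensor:
        # Before: hidden_state = self.llada.get_last_hidden_state(input_ids)
        # After: the embeddings loaded from disk are used as the hidden state.
        hidden_state = input_embeddings  # [batch, max_length, d_model]

        # Masked mean pooling over the sequence (an assumption, not the repo's method).
        mask = input_mask.unsqueeze(-1).to(hidden_state.dtype)  # [batch, max_length, 1]
        pooled = (hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

        return self.head(pooled).squeeze(-1)  # [batch]


# Usage with random tensors shaped like a collated dataloader batch.
batch, max_length, d_model = 2, 8, 16
model = EmbeddingRegressor(d_model)
input_embeddings = torch.randn(batch, max_length, d_model)
input_mask = torch.ones(batch, max_length, dtype=torch.bool)
prediction = model(input_embeddings, input_mask)
```

With this shape of interface, input_ids can stay in the batch dictionary untouched, and the forward pass simply never reads them.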
@VittorioRossi I've written the wrapper for the trainer. What remains is to adapt the models' forward pass and test everything (I haven't run mine yet because some pieces are still missing, so there will be bugs, but once the whole refactor is finished we'll fix them).
… embeddings and not compute them
No description provided.