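"""Training loop for an audio classification model.

Supports single-GPU and DistributedDataParallel (DDP) training with
gradient accumulation. Statistics, progress bars, and checkpointing are
delegated to utils.StatsMaker.StatisticsMaker; evaluation to Evaluater.
"""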
import torch
from torch.nn import functional as nnf
# distributed training
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
# custom
from utils.StatsMaker import StatisticsMaker
from Evaluater import Evaluater


class Trainer:
    def __init__(
        self,
        model: torch.nn.Module,
        model_name: str,
        train_dataloader: DataLoader,
        test_dataloader: DataLoader,
        optimizer: torch.optim.Optimizer,
        gpu_id: int,
        is_distributed: bool,
        main_gpu_id: int = 0,
        accum_iter: int = 1,
        without_eval: bool = False,
    ) -> None:
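        """
        Args:
            model: network to train; moved to ``gpu_id`` and wrapped in DDP
                when ``is_distributed`` is True.
            model_name: run name passed to StatisticsMaker for logging.
            gpu_id: CUDA device index this process trains on.
            main_gpu_id: in the distributed case, the rank handed to
                StatisticsMaker as the logging owner (inferred from usage).
            accum_iter: number of batches to accumulate gradients over
                before each optimizer step.
            without_eval: if True, skip the per-epoch evaluation pass.
        """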
        self.gpu_id = gpu_id
        self.without_eval = without_eval
        self.is_distributed = is_distributed
        if is_distributed:
            self.stats = StatisticsMaker(model_name, gpu_id, main_gpu_id)
        else:
            self.stats = StatisticsMaker(model_name, gpu_id, gpu_id)
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        self.accum_iter = accum_iter
        self.count_iter = 0
        self.optimizer = optimizer
        self.eval_every = 1
        self.save_every = 1
        self.model = model.to(gpu_id)
        if self.is_distributed:
            self.model = DDP(self.model, device_ids=[gpu_id])
            self.module = self.model.module
        else:
            self.module = self.model
        self.evaluater = Evaluater(
            self.model, self.test_dataloader, self.gpu_id, is_distributed=self.is_distributed
        )

    def _run_batch(self, audio, targets):
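        """Forward/backward on one batch with gradient accumulation.

        The loss is scaled by 1/accum_iter so that ``accum_iter``
        accumulated backward passes match one full-batch step; the
        unscaled value is logged. Note: if the epoch length is not a
        multiple of ``accum_iter``, leftover gradients carry over into
        the next optimizer step.
        """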
        logits = self.model(audio)
        loss = nnf.cross_entropy(logits, targets)
        loss = loss / self.accum_iter
        self.count_iter += 1
        self.stats.add_loss(loss.item() * self.accum_iter)
        loss.backward()
        if self.count_iter == self.accum_iter:
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.count_iter = 0

    def _run_train_epoch(self, epoch):
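        """Run one training epoch.

        With DDP, ``sampler.set_epoch(epoch)`` reseeds the
        DistributedSampler so each epoch uses a different shuffling
        order across ranks.
        """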
        if self.is_distributed:
            self.train_dataloader.sampler.set_epoch(epoch)
        self.model.train()
        for audio, targets in self.stats.make_pbar(self.train_dataloader):
            audio = audio.to(self.gpu_id)
            targets = targets.to(self.gpu_id)
            self._run_batch(audio, targets)
            self.stats.set_description()
        self.stats.save_epoch_loss()

    def _run_eval(self, epoch):
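        """Evaluate on the test dataloader and hand metrics to the stats maker."""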
        metrics = self.evaluater.eval()
        self.stats.process_metrics(self.module, metrics)

    def train(self, max_epochs: int, start_epoch: int = 0):
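        """Train for epochs ``[start_epoch, max_epochs)``.

        Saves parameters every ``save_every`` epochs and evaluates every
        ``eval_every`` epochs unless ``without_eval`` is set.
        """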
        for epoch in range(start_epoch, max_epochs):
            self.stats.set_epoch(epoch)
            self.stats.epoch_time_measure()
            self._run_train_epoch(epoch)
            if epoch % self.save_every == 0:
                self.stats.save_last_params(self.module)
            if epoch % self.eval_every == 0 and not self.without_eval:
                self._run_eval(epoch)
            self.stats.epoch_time_measure()

    def evaluate(self):
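        """Run a standalone evaluation pass and return its metrics."""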
        return self.evaluater.eval()
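
if __name__ == "__main__":
    # Minimal single-GPU usage sketch (not part of the original file).
    # The toy model, random data, and hyperparameters below are
    # illustrative assumptions; real runs should wire in a model and
    # datasets from this repo, and the DDP path additionally needs a
    # torch.distributed launch (init_process_group etc.), omitted here.
    # Assumes this repo's StatisticsMaker/Evaluater are importable and a
    # CUDA device is available.
    from torch.utils.data import TensorDataset

    num_classes = 10
    sample_len = 16_000  # hypothetical: one second of 16 kHz audio
    model = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(sample_len, num_classes),
    )

    waveforms = torch.randn(64, sample_len)
    labels = torch.randint(0, num_classes, (64,))
    loader = DataLoader(TensorDataset(waveforms, labels), batch_size=8, shuffle=True)

    trainer = Trainer(
        model=model,
        model_name="toy_run",
        train_dataloader=loader,
        test_dataloader=loader,
        optimizer=torch.optim.AdamW(model.parameters(), lr=3e-4),
        gpu_id=0,              # moves the model to cuda:0
        is_distributed=False,  # single-process: no DDP wrapping
        accum_iter=2,          # effective batch size 8 * 2 = 16
    )
    trainer.train(max_epochs=2)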