train.py
import lightning as L
import torch
from cyclopts import App
from lightning.pytorch import seed_everything
from lightning.pytorch.callbacks import (
GradientAccumulationScheduler,
LearningRateMonitor,
ModelCheckpoint,
StochasticWeightAveraging,
)
from lightning.pytorch.loggers import TensorBoardLogger

from src.configs import ModelTrainConfig
from src.data_module import AmongUsDatamodule
from src.models import ModelFcosPretrained
from src.utils import TestEveryNEpochs

app = App(name="Define Config for training:")


@app.command
def train_fcos(cfg: ModelTrainConfig = ModelTrainConfig()):
training_cfg = cfg.training_cfg
seed_everything(training_cfg.seed)
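    # Allow TF32 matmul kernels on supported GPUs: faster float32 matmuls at slightly reduced precision.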
torch.set_float32_matmul_precision("high")
    # Initialize the DataModule
data_module = AmongUsDatamodule(
cfg.datamodule_cfg, cfg.creation_cfg, cfg.transform_cfg
)
# Setup TensorBoard logger
tb_logger = TensorBoardLogger(
save_dir="logs", name=training_cfg.logger_name, version=None
)
# Setup callbacks
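    # Keep the three best checkpoints ranked by validation mAP, plus the most recent one.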
checkpoint_callback = ModelCheckpoint(
monitor="mAP_score_val_generated",
mode="max",
save_top_k=3,
dirpath="checkpoints",
filename="fcos-{epoch:02d}-{loss_val:.4f}",
save_last=True,
enable_version_counter=True,
)
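    # Change the gradient-accumulation factor on a per-epoch schedule taken from the config.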
grad_acum = GradientAccumulationScheduler(
scheduling={int(k): v for k, v in training_cfg.grad_acum_scheduling.items()}
)
lr_monitor = LearningRateMonitor(logging_interval="step")
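    # Stochastic Weight Averaging over the tail of training, with cosine annealing of the SWA learning rate.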
swa = StochasticWeightAveraging(
swa_lrs=training_cfg.swa_lrs,
swa_epoch_start=250,
annealing_epochs=10,
annealing_strategy="cos",
)
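    # Single-GPU trainer; dataloaders are reloaded every generate_every_epoch epochs,
    # presumably so the datamodule can pick up newly generated samples.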
trainer = L.Trainer(
accelerator="gpu",
max_epochs=training_cfg.num_epochs,
logger=tb_logger,
callbacks=[checkpoint_callback, lr_monitor, swa, grad_acum],
log_every_n_steps=10,
gradient_clip_val=6.0,
enable_progress_bar=True,
limit_train_batches=training_cfg.train_epoch_len,
limit_val_batches=training_cfg.val_epoch_len,
reload_dataloaders_every_n_epochs=cfg.datamodule_cfg.generate_every_epoch,
)
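    # Fine-tune from an existing checkpoint when one is configured; otherwise build a fresh model from the config.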
if training_cfg.finetune_chk is not None:
model = ModelFcosPretrained.load_from_checkpoint(
training_cfg.finetune_chk, weights_only=False
)
else:
model = ModelFcosPretrained(cfg)
trainer.fit(model=model, datamodule=data_module)


if __name__ == "__main__":
    app()
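# Example invocation (assumption: cyclopts maps the function name to a kebab-case
# command by default, and the config fields are exposed as CLI options):
#   python train.py train-fcos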