diff --git a/efold/core/callbacks.py b/efold/core/callbacks.py index 5593d0a..72677c4 100644 --- a/efold/core/callbacks.py +++ b/efold/core/callbacks.py @@ -1,27 +1,9 @@ -import os -from lightning import LightningModule, Trainer import lightning.pytorch as pl -import torch -import numpy as np -import pandas as pd -from lightning.pytorch import LightningModule, Trainer +from lightning.pytorch import Trainer from lightning.pytorch.utilities import rank_zero_only import wandb -from typing import Any -from .visualisation import plot_factory -from .metrics import metric_factory -from .datamodule import DataModule from .loader import Loader -from .batch import Batch -from ..config import ( - TEST_SETS_NAMES, - REF_METRIC_SIGN, - REFERENCE_METRIC, - DATA_TYPES_TEST_SETS, - POSSIBLE_METRICS, -) -from .logger import Logger, LocalLogger class ModelCheckpoint(pl.Callback): diff --git a/efold/core/model.py b/efold/core/model.py index f854e95..c1f77a6 100644 --- a/efold/core/model.py +++ b/efold/core/model.py @@ -66,7 +66,9 @@ def configure_optimizers(self): if not hasattr(self, "gamma") or self.gamma is None: return optimizer - scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.gamma) + scheduler = torch.optim.lr_scheduler.LinearWarmupCosineAnnealingLR( + optimizer, warmup_epochs=3, max_epochs=self.max_epochs + ) return [optimizer], [scheduler] def _loss_signal(self, batch: Batch, data_type: str): diff --git a/scripts/efold_training.py b/scripts/efold_training.py index bc84a7f..34ddc8a 100644 --- a/scripts/efold_training.py +++ b/scripts/efold_training.py @@ -59,8 +59,10 @@ trainer = Trainer( accelerator=device, devices=n_gpu if STRATEGY == "ddp" else 1, - strategy=DDPStrategy(find_unused_parameters=False) if STRATEGY == "ddp" else 'auto', - max_epochs=15, + strategy=DDPStrategy(find_unused_parameters=False) + if STRATEGY == "ddp" + else "auto", + max_epochs=30, log_every_n_steps=1, accumulate_grad_batches=32, use_distributed_sampler=STRATEGY != "ddp", diff --git a/scripts/ribonanza-template.py b/scripts/ribonanza-template.py index 5a28d6d..69819d1 100644 --- a/scripts/ribonanza-template.py +++ b/scripts/ribonanza-template.py @@ -6,7 +6,7 @@ sys.path.append(os.path.dirname(os.path.dirname(__file__))) from lightning.pytorch import Trainer -from efold.core.callbacks import WandbFitLogger, KaggleLogger +from efold.core.callbacks import WandbFitLogger from efold.config import device from efold import DataModule, create_model import torch