diff --git a/sgm/modules/diffusionmodules/loss.py b/sgm/modules/diffusionmodules/loss.py index c397c658..0e000a5e 100644 --- a/sgm/modules/diffusionmodules/loss.py +++ b/sgm/modules/diffusionmodules/loss.py @@ -17,6 +17,7 @@ def __init__( loss_type: str = "l2", offset_noise_level: float = 0.0, batch2model_keys: Optional[Union[str, List[str]]] = None, + n_frames: Optional[int] = None, ): super().__init__() @@ -27,6 +28,7 @@ def __init__( self.loss_type = loss_type self.offset_noise_level = offset_noise_level + self.n_frames = n_frames if loss_type == "lpips": self.lpips = LPIPS().eval()