diff --git a/examples/diffusers/quantization/pipeline_manager.py b/examples/diffusers/quantization/pipeline_manager.py index f62aeffca98..5b46c66292a 100644 --- a/examples/diffusers/quantization/pipeline_manager.py +++ b/examples/diffusers/quantization/pipeline_manager.py @@ -191,19 +191,25 @@ def _ensure_ltx2_transformer_cached(self) -> None: if not self.pipe: raise RuntimeError("Pipeline not created. Call create_pipeline() first.") if self._transformer is None: - transformer = self.pipe.stage_1_model_ledger.transformer() - self.pipe.stage_1_model_ledger.transformer = lambda: transformer - self._transformer = transformer + self._transformer = self.pipe.stage_1_model_ledger.transformer() + # Route the ledger getter through this PipelineManager's attribute so the + # cached instance lives in exactly one controllable place (self._transformer). + # A `lambda: transformer` closure would capture the instance inside the + # lambda's cell, which prevents LTX-2's `del transformer` after stage 1 + # inference from releasing any references — it only drops the local binding. + pm = self + self.pipe.stage_1_model_ledger.transformer = lambda: pm._transformer def _ensure_ltx2_video_decoder_cached(self) -> None: if not self.pipe: raise RuntimeError("Pipeline not created. Call create_pipeline() first.") if self._video_decoder is None: - video_decoder = self.pipe.stage_1_model_ledger.video_decoder() - # Cache it so subsequent calls return the same (quantized) instance - self.pipe.stage_1_model_ledger.video_decoder = lambda: video_decoder - self.pipe.stage_2_model_ledger.video_decoder = lambda: video_decoder - self._video_decoder = video_decoder + self._video_decoder = self.pipe.stage_1_model_ledger.video_decoder() + # Cache it so subsequent calls return the same (quantized) instance. + # Same rationale as the transformer patch: route through self._video_decoder. + pm = self + self.pipe.stage_1_model_ledger.video_decoder = lambda: pm._video_decoder + self.pipe.stage_2_model_ledger.video_decoder = lambda: pm._video_decoder def _create_ltx2_pipeline(self) -> Any: params = dict(self.config.extra_params)