diff --git a/src/streamdiffusion/acceleration/tensorrt/builder.py b/src/streamdiffusion/acceleration/tensorrt/builder.py index f79fff72..731cb0d4 100644 --- a/src/streamdiffusion/acceleration/tensorrt/builder.py +++ b/src/streamdiffusion/acceleration/tensorrt/builder.py @@ -11,6 +11,7 @@ from .models.models import BaseModel from .utilities import ( + TRT_LOGGER, build_engine, export_onnx, optimize_onnx, @@ -291,7 +292,7 @@ def _quant_fn(): import tensorrt as trt - _rt = trt.Runtime(trt.Logger(trt.Logger.WARNING)) + _rt = trt.Runtime(TRT_LOGGER) with open(engine_path, "rb") as _f: _eng = _rt.deserialize_cuda_engine(_f.read()) _insp = _eng.create_engine_inspector() diff --git a/src/streamdiffusion/acceleration/tensorrt/utilities.py b/src/streamdiffusion/acceleration/tensorrt/utilities.py index f815d459..9fa1e12d 100644 --- a/src/streamdiffusion/acceleration/tensorrt/utilities.py +++ b/src/streamdiffusion/acceleration/tensorrt/utilities.py @@ -43,13 +43,14 @@ from polygraphy import cuda from polygraphy.backend.common import bytes_from_path from polygraphy.backend.trt import engine_from_bytes +from polygraphy.backend.trt.util import get_trt_logger from .models.models import CLIP, VAE, BaseModel, UNet, VAEEncoder logger = logging.getLogger(__name__) -TRT_LOGGER = trt.Logger(trt.Logger.ERROR) +TRT_LOGGER = get_trt_logger() # polygraphy singleton — shared with engine_from_bytes() from ...model_detection import detect_model @@ -624,13 +625,12 @@ def build( # set_preview_feature, or SPARSE_WEIGHTS. We use the raw API (same as # the FP8 path) so all parameters are available for both precision paths. - build_logger = trt.Logger(trt.Logger.WARNING) - builder = trt.Builder(build_logger) + builder = trt.Builder(TRT_LOGGER) network_flags = 0 network = builder.create_network(network_flags) - parser = trt.OnnxParser(network, build_logger) + parser = trt.OnnxParser(network, TRT_LOGGER) parser.set_flag(trt.OnnxParserFlag.NATIVE_INSTANCENORM) success = parser.parse_from_file(onnx_path) if not success: @@ -732,16 +732,14 @@ def _build_fp8( gpu_profile: Hardware-aware build parameters from detect_gpu_profile(). dynamic_shapes: Whether the engine uses dynamic input shapes. """ - build_logger = trt.Logger(trt.Logger.WARNING) - - builder = trt.Builder(build_logger) + builder = trt.Builder(TRT_LOGGER) # STRONGLY_TYPED: required for FP8. Tells TRT to use the data-type annotations # from Q/DQ nodes rather than running its own precision heuristics. network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED) network = builder.create_network(network_flags) - parser = trt.OnnxParser(network, build_logger) + parser = trt.OnnxParser(network, TRT_LOGGER) # NATIVE_INSTANCENORM: use TRT's fused InstanceNorm/GroupNorm kernel instead # of decomposing into primitive ops. Diffusion UNets use GroupNorm heavily. parser.set_flag(trt.OnnxParserFlag.NATIVE_INSTANCENORM) diff --git a/src/streamdiffusion/config.py b/src/streamdiffusion/config.py index b52e8233..c3bc1d05 100644 --- a/src/streamdiffusion/config.py +++ b/src/streamdiffusion/config.py @@ -498,8 +498,8 @@ def _validate_config(config: Dict[str, Any]) -> None: if not isinstance(seed_value, int) or seed_value < 0: raise ValueError(f"_validate_config: Seed value {i} must be a non-negative integer") - if not isinstance(weight, (int, float)) or weight < 0: - raise ValueError(f"_validate_config: Seed weight {i} must be a non-negative number") + if not isinstance(weight, (int, float)) or weight < 0: + raise ValueError(f"_validate_config: Seed weight {i} must be a non-negative number") interpolation_method = seed_blend_config.get("interpolation_method", "linear") if interpolation_method not in ["linear", "slerp"]: diff --git a/src/streamdiffusion/model_detection.py b/src/streamdiffusion/model_detection.py index fbd28933..e98d8e9b 100644 --- a/src/streamdiffusion/model_detection.py +++ b/src/streamdiffusion/model_detection.py @@ -169,7 +169,7 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str return result -def detect_unet_characteristics(unet: UNet2DConditionModel) -> Dict[str, any]: +def detect_unet_characteristics(unet: UNet2DConditionModel) -> Dict[str, Any]: """Detect detailed UNet characteristics including SDXL-specific features""" config = unet.config @@ -246,12 +246,12 @@ def detect_model_from_diffusers_unet(unet: UNet2DConditionModel) -> str: return "SD21" if cross_attention_dim == 768: - print( + logger.warning( f"detect_model_from_diffusers_unet: Unknown SD1.5-like model with channels {block_out_channels}, defaulting to SD15" ) return "SD15" elif cross_attention_dim == 1024: - print( + logger.warning( f"detect_model_from_diffusers_unet: Unknown SD2.1-like model with channels {block_out_channels}, defaulting to SD21" ) return "SD21" @@ -379,12 +379,18 @@ def validate_architecture(arch_dict: Dict[str, Any], model_type: str) -> Dict[st else: arch_dict[key] = tuple(arch_dict[key]) else: - arch_dict[key] = preset[key] + raise ValueError( + f"validate_architecture: '{key}' has unsupported type {type(arch_dict[key]).__name__!r}; " + f"expected list, int, or tuple" + ) # Validate sequence lengths match expected_levels = len(arch_dict["channel_mult"]) for key in ["num_res_blocks", "transformer_depth"]: if key in arch_dict and len(arch_dict[key]) != expected_levels: - arch_dict[key] = preset[key] + raise ValueError( + f"validate_architecture: '{key}' has {len(arch_dict[key])} levels but " + f"'channel_mult' has {expected_levels}; they must match" + ) return arch_dict diff --git a/src/streamdiffusion/pipeline.py b/src/streamdiffusion/pipeline.py index 0079d966..539ad411 100644 --- a/src/streamdiffusion/pipeline.py +++ b/src/streamdiffusion/pipeline.py @@ -50,7 +50,7 @@ def __init__( normalize_seed_weights: bool = True, scheduler: Literal["lcm", "tcd"] = "lcm", sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = "normal", - kvo_cache: List[torch.Tensor] = [], + kvo_cache: Optional[List[torch.Tensor]] = None, cache_interval: int = 1, cache_maxframes: int = 1, ) -> None: @@ -173,7 +173,7 @@ def __init__( self._cached_cfg_type: Optional[str] = None self._cached_guidance_scale: Optional[float] = None - self.kvo_cache = kvo_cache + self.kvo_cache = kvo_cache if kvo_cache is not None else [] self._kvo_buckets = None self._kvo_outputs_by_bucket = None self.cache_interval = cache_interval diff --git a/src/streamdiffusion/wrapper.py b/src/streamdiffusion/wrapper.py index 4b9d89fc..cd8aba4b 100644 --- a/src/streamdiffusion/wrapper.py +++ b/src/streamdiffusion/wrapper.py @@ -336,9 +336,11 @@ def __init__( self.use_denoising_batch = use_denoising_batch # safety checker is only supported for TensorRT acceleration self.use_safety_checker = use_safety_checker and (acceleration == "tensorrt") - self.set_nsfw_fallback_img(height, width) self.safety_checker_fallback_type = safety_checker_fallback_type self.safety_checker_threshold = safety_checker_threshold + # Caches the last clean (non-flagged) pipeline tensor for the "previous" fallback strategy. + # Operates in diffusion range [-1, 1]; set by _apply_safety_checker(). + self._prev_clean_tensor: Optional[torch.Tensor] = None self.fp8 = fp8 self.static_shapes = static_shapes self.fp8_allow_fp16_fallback = fp8_allow_fp16_fallback @@ -770,8 +772,6 @@ def _process_skip_diffusion( The processed image with hooks applied. """ - # TODO: add safety checker call somewhere in this method - if self.mode == "txt2img": raise RuntimeError( "_process_skip_diffusion: skip_diffusion mode not applicable for txt2img - no input image" @@ -798,6 +798,9 @@ def _process_skip_diffusion( # Apply image postprocessing hooks (expect [-1,1] range - post-VAE decoding) processed_tensor = self.stream._apply_image_postprocessing_hooks(processed_tensor) + # Screen skip-diffusion output too (raw [-1, 1] tensor, before postprocess/IPC export). + processed_tensor = self._apply_safety_checker(processed_tensor) + # Final postprocessing for output format return self.postprocess_image(processed_tensor, output_type=self.output_type) @@ -824,18 +827,9 @@ def txt2img(self, prompt: Optional[str] = None) -> Union[Image.Image, List[Image else: image_tensor = self.stream.txt2img(self.frame_buffer_size) + image_tensor = self._apply_safety_checker(image_tensor) image = self.postprocess_image(image_tensor, output_type=self.output_type) - if self.use_safety_checker: - if self.output_type != "pt": - denormalized_image_tensor = (image_tensor / 2 + 0.5).clamp(0, 1).to(self.device) - else: - denormalized_image_tensor = image - if self.safety_checker(denormalized_image_tensor, self.safety_checker_threshold): - image = self.nsfw_fallback_img - elif self.safety_checker_fallback_type == "previous": - self.nsfw_fallback_img = image - return image def img2img( @@ -865,17 +859,8 @@ def img2img( # Full pipeline with diffusion image_tensor = self.stream(image) + image_tensor = self._apply_safety_checker(image_tensor) image = self.postprocess_image(image_tensor, output_type=self.output_type) - if self.use_safety_checker: - if self.output_type != "pt": - denormalized_image_tensor = (image_tensor / 2 + 0.5).clamp(0, 1).to(self.device) - else: - denormalized_image_tensor = image - if self.safety_checker(denormalized_image_tensor, self.safety_checker_threshold): - image = self.nsfw_fallback_img - logger.info(f"NSFW content detected, falling back to {self.nsfw_fallback_img} frame") - elif self.safety_checker_fallback_type == "previous": - self.nsfw_fallback_img = image return image @@ -1050,7 +1035,7 @@ def cleanup_cuda_ipc(self) -> None: try: self._cuda_ipc_exporter.close() except Exception: - pass + logger.debug("cleanup_cuda_ipc: _cuda_ipc_exporter.close() failed", exc_info=True) self._cuda_ipc_exporter = None def _denormalize_on_gpu(self, image_tensor: torch.Tensor) -> torch.Tensor: @@ -1136,29 +1121,42 @@ def _tensor_to_pil_optimized(self, image_tensor: torch.Tensor) -> List[Image.Ima return pil_images - def set_nsfw_fallback_img(self, height: int, width: int) -> None: - """ - Set the NSFW fallback image used when safety checker blocks content. + def _apply_safety_checker(self, image_tensor: torch.Tensor) -> torch.Tensor: + """Run the NSFW check on the raw pipeline tensor and substitute a fallback if flagged. - Creates a black RGB image of the specified dimensions that will be returned - when the safety checker determines content should be blocked. + This MUST be called *before* postprocess_image so that the substitution also covers the + CUDA-IPC export path: postprocess_image() exports the frame inside its own body and + returns None, making any post-hoc substitution unreachable and unsafe. Parameters ---------- - height : int - Height of the fallback image in pixels. - width : int - Width of the fallback image in pixels. + image_tensor : torch.Tensor + Raw pipeline output in diffusion range [-1, 1], NCHW. Returns ------- - None + torch.Tensor + The original tensor when content is clean, or a fallback tensor (previous clean + frame, or all-black encoded as -1.0 in diffusion range) when content is flagged. """ - self.nsfw_fallback_img = Image.new("RGB", (height, width), (0, 0, 0)) - if self.output_type == "pt": - self.nsfw_fallback_img = torch.from_numpy(np.array(self.nsfw_fallback_img)).unsqueeze(0) - elif self.output_type == "np": - self.nsfw_fallback_img = np.expand_dims(np.array(self.nsfw_fallback_img), axis=0) + if not self.use_safety_checker: + return image_tensor + + # Denormalize to [0, 1] NCHW for the classifier; stays on GPU. + denormalized = self._denormalize_on_gpu(image_tensor) + + if self.safety_checker(denormalized, self.safety_checker_threshold): + logger.info("NSFW content detected, applying safety fallback frame") + if self.safety_checker_fallback_type == "previous" and self._prev_clean_tensor is not None: + return self._prev_clean_tensor + # -1.0 in diffusion range → 0.0 after denormalization → true black on every output + # path (pt, np, pil, CUDA-IPC). + return torch.full_like(image_tensor, -1.0) + + # Content is clean — cache it for the "previous" fallback strategy. + if self.safety_checker_fallback_type == "previous": + self._prev_clean_tensor = image_tensor.clone() + return image_tensor def _load_model( self, @@ -1522,7 +1520,7 @@ def _load_model( try: stream.pipe.unload_lora_weights() except Exception: - pass + logger.debug("LoRA cleanup: unload_lora_weights() failed after merge failure", exc_info=True) if use_tiny_vae: if vae_id is not None: @@ -2603,7 +2601,7 @@ def cleanup_gpu_memory(self) -> None: self.stream._param_updater.clear_caches() logger.info(" Cleared prompt caches") except Exception: - pass + logger.debug("cleanup_gpu_memory: clear_caches() failed", exc_info=True) # Enhanced TensorRT engine cleanup if hasattr(self, "stream") and self.stream: @@ -2678,7 +2676,7 @@ def cleanup_gpu_memory(self) -> None: del self.stream logger.info(" Cleared stream object") except Exception: - pass + logger.debug("cleanup_gpu_memory: del self.stream failed", exc_info=True) self.stream = None # Release wrapper-level frame buffers so the next model swap allocates fresh diff --git a/tests/unit/test_safety_checker.py b/tests/unit/test_safety_checker.py new file mode 100644 index 00000000..ef5a3d13 --- /dev/null +++ b/tests/unit/test_safety_checker.py @@ -0,0 +1,233 @@ +""" +Regression tests for StreamDiffusionWrapper._apply_safety_checker. + +These tests are deliberately CPU-only and model-free. They construct a +minimal StreamDiffusionWrapper shell via object.__new__ and wire in only the +attributes the helper reads, so the full GPU/TRT stack is not required. + +Root cause being guarded: when use_cuda_ipc_output=True and output_type='pt', +postprocess_image() exports the frame and returns None. The old code fed that +None to self.safety_checker, which called torchvision T.Resize on None and +raised: + TypeError: Unexpected type +This crashed the streaming loop every frame. + +Fix: _apply_safety_checker() runs *before* postprocess_image, operating on the +raw diffusion-range [-1, 1] pipeline tensor so it is always a real tensor +regardless of output path. +""" + +import torch + +from streamdiffusion.wrapper import StreamDiffusionWrapper + + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + + +def _make_wrapper( + *, + use_safety_checker: bool = True, + fallback_type: str = "blank", + threshold: float = 0.5, +): + """Construct a minimal StreamDiffusionWrapper without model loading.""" + w = object.__new__(StreamDiffusionWrapper) + w.use_safety_checker = use_safety_checker + w.safety_checker_threshold = threshold + w.safety_checker_fallback_type = fallback_type + w._prev_clean_tensor = None + # safety_checker is set per-test + w.safety_checker = None + return w + + +def _black_denorm(t: torch.Tensor) -> bool: + """True when t is all-zeros after _denormalize_on_gpu (i.e. all -1.0 raw).""" + return torch.allclose(t, torch.zeros_like(t)) + + +# --------------------------------------------------------------------------- +# test cases +# --------------------------------------------------------------------------- + + +class TestApplySafetyChecker: + # ── Case 1: No-None contract (directly reproduces the old crash) ──────── + def test_checker_never_receives_none(self): + """ + The safety checker must never be called with a None tensor. + This is the exact condition that produced: + TypeError: Unexpected type + in the old post-hoc code path when output_type='pt' and IPC was active. + """ + received: list = [] + + def capturing_checker(tensor, thr): + received.append(tensor) + return True # always flag + + w = _make_wrapper() + w.safety_checker = capturing_checker + + dummy = torch.randn(1, 3, 64, 64) + w._apply_safety_checker(dummy) + + assert len(received) == 1 + arg = received[0] + assert arg is not None, "safety checker received None — regression" + assert isinstance(arg, torch.Tensor), f"expected Tensor, got {type(arg)}" + + # ── Case 2: NSFW → black frame (blank fallback) ────────────────────────── + def test_nsfw_blank_fallback_is_black(self): + w = _make_wrapper(fallback_type="blank") + w.safety_checker = lambda t, thr: True # always flag + + dummy = torch.randn(1, 3, 64, 64) + result = w._apply_safety_checker(dummy) + + assert result is not dummy, "flagged frame should be replaced" + # Denormalize: (x/2+0.5).clamp(0,1); -1.0 → 0.0 = black + denorm = (result / 2 + 0.5).clamp(0, 1) + assert _black_denorm(denorm), "NSFW blank fallback should produce a black frame" + + # ── Case 3: NSFW → previous frame ──────────────────────────────────────── + def test_nsfw_previous_fallback_returns_cached_clean_frame(self): + w = _make_wrapper(fallback_type="previous") + call_count = [0] + + def checker(t, thr): + call_count[0] += 1 + return call_count[0] > 1 # first call: clean; second call: flagged + + w.safety_checker = checker + + clean = torch.randn(1, 3, 64, 64) + _ = w._apply_safety_checker(clean) # primes _prev_clean_tensor + + flagged = torch.randn(1, 3, 64, 64) + result = w._apply_safety_checker(flagged) + + assert result is w._prev_clean_tensor, "previous-fallback should return the cached clean tensor" + + # ── Case 4: Clean frame passthrough ────────────────────────────────────── + def test_clean_frame_returned_unchanged(self): + w = _make_wrapper() + w.safety_checker = lambda t, thr: False # never flag + + dummy = torch.randn(1, 3, 64, 64) + result = w._apply_safety_checker(dummy) + + assert torch.equal(result, dummy), "clean frame should be returned unchanged" + + # ── Case 5: use_safety_checker=False bypasses entirely ─────────────────── + def test_disabled_bypasses_checker(self): + called = [False] + + def should_not_be_called(t, thr): + called[0] = True + return False + + w = _make_wrapper(use_safety_checker=False) + w.safety_checker = should_not_be_called + + dummy = torch.randn(1, 3, 64, 64) + result = w._apply_safety_checker(dummy) + + assert not called[0], "safety checker must not be called when disabled" + assert torch.equal(result, dummy), "disabled checker should return input unchanged" + + # ── Case 6: previous fallback with no cached frame → black ─────────────── + def test_nsfw_previous_fallback_no_cache_falls_back_to_black(self): + """ + If the very first frame is flagged and _prev_clean_tensor is None, + the fallback should still produce a black frame rather than raise. + """ + w = _make_wrapper(fallback_type="previous") + w.safety_checker = lambda t, thr: True # flag immediately + assert w._prev_clean_tensor is None # no cache yet + + dummy = torch.randn(1, 3, 64, 64) + result = w._apply_safety_checker(dummy) + + denorm = (result / 2 + 0.5).clamp(0, 1) + assert _black_denorm(denorm), "when prev cache is empty and content is flagged, fallback must be black" + + # ── Case 7: clean frame caches for previous strategy ───────────────────── + def test_clean_frame_cached_for_previous_strategy(self): + w = _make_wrapper(fallback_type="previous") + w.safety_checker = lambda t, thr: False + + assert w._prev_clean_tensor is None + dummy = torch.randn(1, 3, 64, 64) + w._apply_safety_checker(dummy) + + assert w._prev_clean_tensor is not None, ( + "clean frame with previous strategy should populate _prev_clean_tensor" + ) + + # ── Case 8: clean frame NOT cached for blank strategy ──────────────────── + def test_clean_frame_not_cached_for_blank_strategy(self): + w = _make_wrapper(fallback_type="blank") + w.safety_checker = lambda t, thr: False + + dummy = torch.randn(1, 3, 64, 64) + w._apply_safety_checker(dummy) + + assert w._prev_clean_tensor is None, "_prev_clean_tensor should stay None when fallback_type='blank'" + + # ── Case 9: _process_skip_diffusion wiring — checker is called, result is black ── + def test_skip_diffusion_routes_through_safety_checker(self): + """ + Verify that _process_skip_diffusion actually feeds its tensor through + _apply_safety_checker before postprocess_image, and that a flagged frame + produces the black substitution rather than passing through unscreened. + + This is a wiring test — the behaviour of _apply_safety_checker itself is + covered by the earlier cases. We stub all model-dependent calls with + identity/no-op lambdas so the test is CPU-only and model-free. + """ + received: list = [] + + def capturing_checker(tensor, thr): + received.append(tensor) + return True # always flag + + w = _make_wrapper(fallback_type="blank") + w.safety_checker = capturing_checker + w.mode = "img2img" + w.device = torch.device("cpu") + w.dtype = torch.float32 + + # Stub the stream's pre/post hooks to identity + class _FakeStream: + def _apply_image_preprocessing_hooks(self, t): + return t + + def _apply_image_postprocessing_hooks(self, t): + return t + + w.stream = _FakeStream() + + # _normalize_on_gpu / _denormalize_on_gpu are identity for this test + w._normalize_on_gpu = lambda t: t + w._denormalize_on_gpu = lambda t: t + + # postprocess_image returns its input so we can inspect the tensor + w.postprocess_image = lambda t, output_type=None: t + w.output_type = "pt" + + dummy_image = torch.randn(1, 3, 64, 64) + result = w._process_skip_diffusion(dummy_image) + + # Checker must have been called exactly once with a real tensor + assert len(received) == 1, "safety checker should be called exactly once" + assert isinstance(received[0], torch.Tensor), "checker arg must be a Tensor, not None" + + # Flagged frame must produce the black substitution (all -1.0 → denorm 0.0) + denorm = (result / 2 + 0.5).clamp(0, 1) + assert torch.allclose(denorm, torch.zeros_like(denorm)), ( + "flagged skip-diffusion frame should produce a black output" + )