Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/streamdiffusion/acceleration/tensorrt/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from .models.models import BaseModel
from .utilities import (
TRT_LOGGER,
build_engine,
export_onnx,
optimize_onnx,
Expand Down Expand Up @@ -291,7 +292,7 @@ def _quant_fn():

import tensorrt as trt

_rt = trt.Runtime(trt.Logger(trt.Logger.WARNING))
_rt = trt.Runtime(TRT_LOGGER)
with open(engine_path, "rb") as _f:
_eng = _rt.deserialize_cuda_engine(_f.read())
_insp = _eng.create_engine_inspector()
Expand Down
14 changes: 6 additions & 8 deletions src/streamdiffusion/acceleration/tensorrt/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@
from polygraphy import cuda
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.trt import engine_from_bytes
from polygraphy.backend.trt.util import get_trt_logger

from .models.models import CLIP, VAE, BaseModel, UNet, VAEEncoder


logger = logging.getLogger(__name__)

TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
TRT_LOGGER = get_trt_logger() # polygraphy singleton — shared with engine_from_bytes()

from ...model_detection import detect_model

Expand Down Expand Up @@ -624,13 +625,12 @@ def build(
# set_preview_feature, or SPARSE_WEIGHTS. We use the raw API (same as
# the FP8 path) so all parameters are available for both precision paths.

build_logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(build_logger)
builder = trt.Builder(TRT_LOGGER)

network_flags = 0
network = builder.create_network(network_flags)

parser = trt.OnnxParser(network, build_logger)
parser = trt.OnnxParser(network, TRT_LOGGER)
parser.set_flag(trt.OnnxParserFlag.NATIVE_INSTANCENORM)
success = parser.parse_from_file(onnx_path)
if not success:
Expand Down Expand Up @@ -732,16 +732,14 @@ def _build_fp8(
gpu_profile: Hardware-aware build parameters from detect_gpu_profile().
dynamic_shapes: Whether the engine uses dynamic input shapes.
"""
build_logger = trt.Logger(trt.Logger.WARNING)

builder = trt.Builder(build_logger)
builder = trt.Builder(TRT_LOGGER)

# STRONGLY_TYPED: required for FP8. Tells TRT to use the data-type annotations
# from Q/DQ nodes rather than running its own precision heuristics.
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
network = builder.create_network(network_flags)

parser = trt.OnnxParser(network, build_logger)
parser = trt.OnnxParser(network, TRT_LOGGER)
# NATIVE_INSTANCENORM: use TRT's fused InstanceNorm/GroupNorm kernel instead
# of decomposing into primitive ops. Diffusion UNets use GroupNorm heavily.
parser.set_flag(trt.OnnxParserFlag.NATIVE_INSTANCENORM)
Expand Down
4 changes: 2 additions & 2 deletions src/streamdiffusion/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,8 +498,8 @@ def _validate_config(config: Dict[str, Any]) -> None:
if not isinstance(seed_value, int) or seed_value < 0:
raise ValueError(f"_validate_config: Seed value {i} must be a non-negative integer")

if not isinstance(weight, (int, float)) or weight < 0:
raise ValueError(f"_validate_config: Seed weight {i} must be a non-negative number")
if not isinstance(weight, (int, float)) or weight < 0:
raise ValueError(f"_validate_config: Seed weight {i} must be a non-negative number")

interpolation_method = seed_blend_config.get("interpolation_method", "linear")
if interpolation_method not in ["linear", "slerp"]:
Expand Down
16 changes: 11 additions & 5 deletions src/streamdiffusion/model_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def detect_model(model: torch.nn.Module, pipe: Optional[Any] = None) -> Dict[str
return result


def detect_unet_characteristics(unet: UNet2DConditionModel) -> Dict[str, any]:
def detect_unet_characteristics(unet: UNet2DConditionModel) -> Dict[str, Any]:
"""Detect detailed UNet characteristics including SDXL-specific features"""
config = unet.config

Expand Down Expand Up @@ -246,12 +246,12 @@ def detect_model_from_diffusers_unet(unet: UNet2DConditionModel) -> str:
return "SD21"

if cross_attention_dim == 768:
print(
logger.warning(
f"detect_model_from_diffusers_unet: Unknown SD1.5-like model with channels {block_out_channels}, defaulting to SD15"
)
return "SD15"
elif cross_attention_dim == 1024:
print(
logger.warning(
f"detect_model_from_diffusers_unet: Unknown SD2.1-like model with channels {block_out_channels}, defaulting to SD21"
)
return "SD21"
Expand Down Expand Up @@ -379,12 +379,18 @@ def validate_architecture(arch_dict: Dict[str, Any], model_type: str) -> Dict[st
else:
arch_dict[key] = tuple(arch_dict[key])
else:
arch_dict[key] = preset[key]
raise ValueError(
f"validate_architecture: '{key}' has unsupported type {type(arch_dict[key]).__name__!r}; "
f"expected list, int, or tuple"
)

# Validate sequence lengths match
expected_levels = len(arch_dict["channel_mult"])
for key in ["num_res_blocks", "transformer_depth"]:
if key in arch_dict and len(arch_dict[key]) != expected_levels:
arch_dict[key] = preset[key]
raise ValueError(
f"validate_architecture: '{key}' has {len(arch_dict[key])} levels but "
f"'channel_mult' has {expected_levels}; they must match"
)

return arch_dict
4 changes: 2 additions & 2 deletions src/streamdiffusion/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
normalize_seed_weights: bool = True,
scheduler: Literal["lcm", "tcd"] = "lcm",
sampler: Literal["simple", "sgm uniform", "normal", "ddim", "beta", "karras"] = "normal",
kvo_cache: List[torch.Tensor] = [],
kvo_cache: Optional[List[torch.Tensor]] = None,
cache_interval: int = 1,
cache_maxframes: int = 1,
) -> None:
Expand Down Expand Up @@ -173,7 +173,7 @@ def __init__(
self._cached_cfg_type: Optional[str] = None
self._cached_guidance_scale: Optional[float] = None

self.kvo_cache = kvo_cache
self.kvo_cache = kvo_cache if kvo_cache is not None else []
self._kvo_buckets = None
self._kvo_outputs_by_bucket = None
self.cache_interval = cache_interval
Expand Down
82 changes: 40 additions & 42 deletions src/streamdiffusion/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,9 +336,11 @@ def __init__(
self.use_denoising_batch = use_denoising_batch
# safety checker is only supported for TensorRT acceleration
self.use_safety_checker = use_safety_checker and (acceleration == "tensorrt")
self.set_nsfw_fallback_img(height, width)
self.safety_checker_fallback_type = safety_checker_fallback_type
self.safety_checker_threshold = safety_checker_threshold
# Caches the last clean (non-flagged) pipeline tensor for the "previous" fallback strategy.
# Operates in diffusion range [-1, 1]; set by _apply_safety_checker().
self._prev_clean_tensor: Optional[torch.Tensor] = None
self.fp8 = fp8
self.static_shapes = static_shapes
self.fp8_allow_fp16_fallback = fp8_allow_fp16_fallback
Expand Down Expand Up @@ -770,8 +772,6 @@ def _process_skip_diffusion(
The processed image with hooks applied.
"""

# TODO: add safety checker call somewhere in this method

if self.mode == "txt2img":
raise RuntimeError(
"_process_skip_diffusion: skip_diffusion mode not applicable for txt2img - no input image"
Expand All @@ -798,6 +798,9 @@ def _process_skip_diffusion(
# Apply image postprocessing hooks (expect [-1,1] range - post-VAE decoding)
processed_tensor = self.stream._apply_image_postprocessing_hooks(processed_tensor)

# Screen skip-diffusion output too (raw [-1, 1] tensor, before postprocess/IPC export).
processed_tensor = self._apply_safety_checker(processed_tensor)

# Final postprocessing for output format
return self.postprocess_image(processed_tensor, output_type=self.output_type)

Expand All @@ -824,18 +827,9 @@ def txt2img(self, prompt: Optional[str] = None) -> Union[Image.Image, List[Image
else:
image_tensor = self.stream.txt2img(self.frame_buffer_size)

image_tensor = self._apply_safety_checker(image_tensor)
image = self.postprocess_image(image_tensor, output_type=self.output_type)

if self.use_safety_checker:
if self.output_type != "pt":
denormalized_image_tensor = (image_tensor / 2 + 0.5).clamp(0, 1).to(self.device)
else:
denormalized_image_tensor = image
if self.safety_checker(denormalized_image_tensor, self.safety_checker_threshold):
image = self.nsfw_fallback_img
elif self.safety_checker_fallback_type == "previous":
self.nsfw_fallback_img = image

return image

def img2img(
Expand Down Expand Up @@ -865,17 +859,8 @@ def img2img(

# Full pipeline with diffusion
image_tensor = self.stream(image)
image_tensor = self._apply_safety_checker(image_tensor)
image = self.postprocess_image(image_tensor, output_type=self.output_type)
if self.use_safety_checker:
if self.output_type != "pt":
denormalized_image_tensor = (image_tensor / 2 + 0.5).clamp(0, 1).to(self.device)
else:
denormalized_image_tensor = image
if self.safety_checker(denormalized_image_tensor, self.safety_checker_threshold):
image = self.nsfw_fallback_img
logger.info(f"NSFW content detected, falling back to {self.nsfw_fallback_img} frame")
elif self.safety_checker_fallback_type == "previous":
self.nsfw_fallback_img = image

return image

Expand Down Expand Up @@ -1050,7 +1035,7 @@ def cleanup_cuda_ipc(self) -> None:
try:
self._cuda_ipc_exporter.close()
except Exception:
pass
logger.debug("cleanup_cuda_ipc: _cuda_ipc_exporter.close() failed", exc_info=True)
self._cuda_ipc_exporter = None

def _denormalize_on_gpu(self, image_tensor: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -1136,29 +1121,42 @@ def _tensor_to_pil_optimized(self, image_tensor: torch.Tensor) -> List[Image.Ima

return pil_images

def set_nsfw_fallback_img(self, height: int, width: int) -> None:
"""
Set the NSFW fallback image used when safety checker blocks content.
def _apply_safety_checker(self, image_tensor: torch.Tensor) -> torch.Tensor:
"""Run the NSFW check on the raw pipeline tensor and substitute a fallback if flagged.

Creates a black RGB image of the specified dimensions that will be returned
when the safety checker determines content should be blocked.
This MUST be called *before* postprocess_image so that the substitution also covers the
CUDA-IPC export path: postprocess_image() exports the frame inside its own body and
returns None, making any post-hoc substitution unreachable and unsafe.

Parameters
----------
height : int
Height of the fallback image in pixels.
width : int
Width of the fallback image in pixels.
image_tensor : torch.Tensor
Raw pipeline output in diffusion range [-1, 1], NCHW.

Returns
-------
None
torch.Tensor
The original tensor when content is clean, or a fallback tensor (previous clean
frame, or all-black encoded as -1.0 in diffusion range) when content is flagged.
"""
self.nsfw_fallback_img = Image.new("RGB", (height, width), (0, 0, 0))
if self.output_type == "pt":
self.nsfw_fallback_img = torch.from_numpy(np.array(self.nsfw_fallback_img)).unsqueeze(0)
elif self.output_type == "np":
self.nsfw_fallback_img = np.expand_dims(np.array(self.nsfw_fallback_img), axis=0)
if not self.use_safety_checker:
return image_tensor

# Denormalize to [0, 1] NCHW for the classifier; stays on GPU.
denormalized = self._denormalize_on_gpu(image_tensor)

if self.safety_checker(denormalized, self.safety_checker_threshold):
logger.info("NSFW content detected, applying safety fallback frame")
if self.safety_checker_fallback_type == "previous" and self._prev_clean_tensor is not None:
return self._prev_clean_tensor
# -1.0 in diffusion range → 0.0 after denormalization → true black on every output
# path (pt, np, pil, CUDA-IPC).
return torch.full_like(image_tensor, -1.0)

# Content is clean — cache it for the "previous" fallback strategy.
if self.safety_checker_fallback_type == "previous":
self._prev_clean_tensor = image_tensor.clone()
return image_tensor

def _load_model(
self,
Expand Down Expand Up @@ -1522,7 +1520,7 @@ def _load_model(
try:
stream.pipe.unload_lora_weights()
except Exception:
pass
logger.debug("LoRA cleanup: unload_lora_weights() failed after merge failure", exc_info=True)

if use_tiny_vae:
if vae_id is not None:
Expand Down Expand Up @@ -2603,7 +2601,7 @@ def cleanup_gpu_memory(self) -> None:
self.stream._param_updater.clear_caches()
logger.info(" Cleared prompt caches")
except Exception:
pass
logger.debug("cleanup_gpu_memory: clear_caches() failed", exc_info=True)

# Enhanced TensorRT engine cleanup
if hasattr(self, "stream") and self.stream:
Expand Down Expand Up @@ -2678,7 +2676,7 @@ def cleanup_gpu_memory(self) -> None:
del self.stream
logger.info(" Cleared stream object")
except Exception:
pass
logger.debug("cleanup_gpu_memory: del self.stream failed", exc_info=True)
self.stream = None

# Release wrapper-level frame buffers so the next model swap allocates fresh
Expand Down
Loading