Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions src/gen_worker/pipeline_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1525,12 +1525,39 @@ async def load(
# Move to device with OOM handling
try:
if config.device == "cuda" and torch.cuda.is_available():
logger.info(
"Moving pipeline to CUDA for %s (%.1f GB) ...",
model_id,
model_size_gb,
)
pipeline = pipeline.to("cuda")
logger.info("Pipeline moved to CUDA successfully for %s", model_id)
else:
logger.warning(
"CUDA not available (device=%s, cuda.is_available=%s), "
"pipeline will remain on CPU for %s",
config.device,
torch.cuda.is_available(),
model_id,
)
except torch.cuda.OutOfMemoryError as e:
flush_memory()
raise CudaOutOfMemoryError(
model_id, model_size_gb, get_available_vram_gb()
) from e
except RuntimeError as e:
logger.error(
"CUDA RuntimeError moving %s to GPU: %s — falling back to CPU",
model_id,
e,
)
except Exception as e:
logger.error(
"Unexpected error moving %s to GPU (%s: %s) — falling back to CPU",
model_id,
type(e).__name__,
e,
)

# Apply VAE optimizations (always enabled)
if config.enable_vae_tiling or config.enable_vae_slicing:
Expand Down
17 changes: 17 additions & 0 deletions src/gen_worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -4120,6 +4120,7 @@ async def consume_async() -> None:
logger.info("Task %s completed successfully.", request_id)

except Exception as e:
logger.exception("Task %s failed: %s", request_id, e)
error_type, retryable, safe_message, error_message = self._map_exception(e)
if inference_watchdog is not None:
inference_watchdog.cancel()
Expand Down Expand Up @@ -4666,7 +4667,15 @@ def _resolve_injected_value(self, ctx: RequestContext, requested_type: Any, mode
model_source = str(model_id)
preload_kwargs = {}

logger.info(
"Loading from_pretrained: source=%s type=%s kwargs=%s",
model_source, type_qualname(requested_type), list(preload_kwargs.keys()),
)
obj = from_pretrained(model_source, **preload_kwargs)
logger.info(
"from_pretrained complete: source=%s (%.1fs)",
model_source, time.monotonic() - t_pi0,
)
if rm is not None:
rm.add_pipeline_init_time(int((time.monotonic() - t_pi0) * 1000))
if isinstance(requested_type, type) and not isinstance(obj, requested_type):
Expand Down Expand Up @@ -4704,6 +4713,10 @@ def _resolve_injected_value(self, ctx: RequestContext, requested_type: Any, mode
torch_dtype = kwargs.get("torch_dtype") if isinstance(kwargs, dict) else None
except Exception:
torch_dtype = None
logger.info(
"Moving model to device=%s dtype=%s model=%s ...",
str(ctx.device), torch_dtype, model_id,
)
try:
if torch_dtype is not None:
obj = obj.to(str(ctx.device), dtype=torch_dtype)
Expand All @@ -4712,6 +4725,10 @@ def _resolve_injected_value(self, ctx: RequestContext, requested_type: Any, mode
except TypeError:
# Some objects implement .to(device) but not dtype kwarg.
obj = obj.to(str(ctx.device))
logger.info(
"Model moved to device=%s successfully model=%s",
str(ctx.device), model_id,
)
if rm is not None:
rm.add_gpu_load_time(int((time.monotonic() - t_to0) * 1000))

Expand Down
Loading