Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "gen-worker"
version = "0.3.10"
version = "0.3.7"
description = "A library used to build custom functions in Cozy Creator's serverless function platform."
readme = "README.md"
license = "MIT"
Expand Down
245 changes: 111 additions & 134 deletions src/gen_worker/cozy_cas.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,159 +307,136 @@ def _safe_symlink_dir(target: Path, link: Path) -> None:
max_value=30, # cap backoff at 30s between retries
)
async def _download_one_file(url: str, dst: Path, expected_size: int, expected_blake3: str) -> None:
    """Download a single file with HTTP Range resume, size + blake3 validation.

    Fully async — no blocking calls that would stall the event loop.
    Caller is responsible for ensuring only one coroutine downloads a given
    dst at a time (dedup by digest in _ensure_blobs).

    Raises:
        ValueError: on size or blake3 mismatch of the completed download.
        aiohttp.ClientResponseError: on non-2xx HTTP status.
    """
    import logging
    log = logging.getLogger("gen_worker.download")

    def _human_size(n: int) -> str:
        # Render a byte count as a short human-readable string for logs.
        if n >= 1 << 30:
            return f"{n / (1 << 30):.1f}GB"
        if n >= 1 << 20:
            return f"{n / (1 << 20):.1f}MB"
        if n >= 1 << 10:
            return f"{n / (1 << 10):.1f}KB"
        return f"{n}B"

    # Already downloaded and valid?
    if dst.exists():
        try:
            if expected_size and dst.stat().st_size != expected_size:
                raise ValueError("size mismatch")
            if expected_blake3:
                got = _blake3_file(dst)
                if got.lower() != expected_blake3.lower():
                    raise ValueError("blake3 mismatch")
            log.info("download_cached path=%s size=%s", dst.name, _human_size(dst.stat().st_size))
            return
        except Exception:
            pass  # re-download

    # Use sock_read instead of a total timeout so actively-streaming large
    # files are not killed. total=None lets multi-GB downloads run as long as
    # data keeps flowing; sock_read catches genuine stalls.
    timeout = aiohttp.ClientTimeout(
        total=None,
        sock_connect=float(os.getenv("WORKER_MODEL_DOWNLOAD_SOCK_CONNECT_TIMEOUT_S", "60")),
        sock_read=float(os.getenv("WORKER_MODEL_DOWNLOAD_SOCK_READ_TIMEOUT_S", "180")),
    )
    tmp = dst.with_suffix(dst.suffix + ".part")

    # Resume from partial download if available.
    offset = 0
    if tmp.exists():
        try:
            offset = tmp.stat().st_size
        except OSError:
            # tmp may have been renamed tmp→dst between exists() and stat().
            offset = 0
        if expected_size and offset > expected_size:
            # Partial file is larger than the expected blob — it is garbage.
            tmp.unlink(missing_ok=True)
            offset = 0

    # Partial file already complete? Validate and finalize.
    if offset and expected_size and offset == expected_size:
        got = _blake3_file(tmp)
        if expected_blake3 and got.lower() != expected_blake3.lower():
            log.warning("partial_corrupt path=%s (blake3 mismatch, restarting)", dst.name)
            tmp.unlink(missing_ok=True)
            offset = 0
        else:
            tmp.rename(dst)
            log.info("download_resumed_complete path=%s size=%s", dst.name, _human_size(expected_size))
            return

    headers: Dict[str, str] = {}
    mode = "wb"
    if offset and expected_size:
        headers["Range"] = f"bytes={offset}-"
        mode = "ab"
        log.info("download_resume path=%s offset=%s/%s (%s/%s)",
                 dst.name, offset, expected_size,
                 _human_size(offset), _human_size(expected_size))
    else:
        log.info("download_start path=%s size=%s blake3=%s",
                 dst.name, _human_size(expected_size) if expected_size else "unknown",
                 (expected_blake3 or "n/a")[:16])

    async def _stream(resp: aiohttp.ClientResponse, *, write_mode: str, start: int) -> None:
        # Stream the response body into tmp, enforcing the expected-size cap
        # and emitting coarse progress logs (~every 10% or 50MB, whichever is
        # larger; every 100MB when the total size is unknown).
        downloaded = start
        last_log = start
        log_every = max(expected_size // 10, 50 << 20) if expected_size else (100 << 20)
        with open(tmp, write_mode) as f:
            async for chunk in resp.content.iter_chunked(1 << 20):
                if not chunk:
                    continue
                f.write(chunk)
                downloaded += len(chunk)
                if expected_size and downloaded > expected_size:
                    raise ValueError(f"download exceeded expected size ({downloaded} > {expected_size})")
                if downloaded - last_log >= log_every:
                    pct = f" ({100 * downloaded // expected_size}%)" if expected_size else ""
                    log.info("download_progress path=%s downloaded=%s%s",
                             dst.name, _human_size(downloaded), pct)
                    last_log = downloaded

    async with aiohttp.ClientSession(timeout=timeout) as session:
        async with session.get(url, headers=headers) as resp:
            content_range = str(resp.headers.get("Content-Range") or "").strip()

            # Server ignored our Range request (200), or some gateways return
            # 206 with an unexpected range start. Either way, restart from
            # byte 0 to avoid appending duplicated bytes to the partial file.
            if offset and (
                resp.status == 200
                or (resp.status == 206 and not content_range.startswith(f"bytes {offset}-"))
            ):
                log.info("download_range_ignored path=%s status=%s (restarting from 0)", dst.name, resp.status)
                resp.release()
                async with session.get(url) as resp2:
                    resp2.raise_for_status()
                    await _stream(resp2, write_mode="wb", start=0)
            else:
                resp.raise_for_status()
                await _stream(resp, write_mode=mode, start=offset)

    # Validate the fully-downloaded temp file before finalizing.
    actual_size = tmp.stat().st_size
    if expected_size and actual_size != expected_size:
        log.error("download_size_mismatch path=%s expected=%s got=%s", dst.name, expected_size, actual_size)
        tmp.unlink(missing_ok=True)
        raise ValueError(f"size mismatch (expected {expected_size}, got {actual_size})")

    if expected_blake3:
        got = _blake3_file(tmp)
        if got.lower() != expected_blake3.lower():
            log.error("download_blake3_mismatch path=%s expected=%s got=%s",
                      dst.name, expected_blake3[:16], got[:16])
            tmp.unlink(missing_ok=True)
            raise ValueError("blake3 mismatch")

    # Atomic finalize: replace() won't fail if dst already exists.
    tmp.replace(dst)
    log.info("download_done path=%s size=%s", dst.name, _human_size(actual_size))


def _blake3_file(path: Path, chunk_size: int = 1 << 20) -> str:
Expand Down
57 changes: 14 additions & 43 deletions src/gen_worker/cozy_pipeline_spec.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,16 @@
from __future__ import annotations

import json
import logging
import os
import tomllib
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import yaml

logger = logging.getLogger(__name__)

COZY_PIPELINE_LOCK_FILENAME = "cozy.pipeline.lock.yaml"
COZY_PIPELINE_FILENAME = "cozy.pipeline.yaml"
PIPELINE_LOCK_TOML_FILENAME = "pipeline.lock"
PIPELINE_TOML_FILENAME = "pipeline.toml"
DIFFUSERS_MODEL_INDEX_FILENAME = "model_index.json"


Expand All @@ -38,13 +33,6 @@ def custom_pipeline_path(self) -> Optional[str]:
s = str(v).strip()
return s or None

@property
def variant(self) -> Optional[str]:
    """Diffusers variant (e.g. 'fp16', 'fp8') from the pipeline spec.

    Reads `pipe.variant` from the raw spec mapping; returns None when the
    key is absent, empty, or whitespace-only.
    """
    pipe_section = self.raw.get("pipe") or {}
    value = str(pipe_section.get("variant") or "").strip()
    return value if value else None


def _safe_child_path(root: Path, rel: str) -> Path:
# Ensure rel doesn't escape root (best-effort).
def load_cozy_pipeline_spec(model_root: Path) -> Optional[CozyPipelineSpec]:
    """Load the cozy pipeline spec from a model root directory, if present.

    This is a worker-side helper used during pipeline loading to implement:
    - prefer `cozy.pipeline.lock.yaml` when present
    - fall back to `cozy.pipeline.yaml` otherwise

    Returns:
        CozyPipelineSpec when a spec file exists, else None.

    Raises:
        ValueError: when the spec is not a mapping, or declares an
            unsupported apiVersion/kind.
    """
    root = Path(model_root)
    lock_path = root / COZY_PIPELINE_LOCK_FILENAME
    # Lock file wins over the unlocked spec when both exist.
    spec_path = lock_path if lock_path.exists() else (root / COZY_PIPELINE_FILENAME)
    if not spec_path.exists():
        return None

    raw = yaml.safe_load(spec_path.read_text(encoding="utf-8"))
    if not isinstance(raw, dict):
        raise ValueError("invalid cozy pipeline spec (expected mapping)")
    # apiVersion/kind are optional; when present they must match the one
    # schema this worker understands.
    api = str(raw.get("apiVersion") or "").strip()
    kind = str(raw.get("kind") or "").strip()
    if api and api != "v1":
        raise ValueError(f"unsupported cozy pipeline apiVersion: {api!r}")
    if kind and kind != "DiffusersPipeline":
        raise ValueError(f"unsupported cozy pipeline kind: {kind!r}")

    return CozyPipelineSpec(source_path=spec_path, raw=raw)


def cozy_custom_pipeline_arg(model_root: Path, spec: CozyPipelineSpec) -> Optional[str]:
Expand Down
Loading
Loading