Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/ollama_queue_proxy/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@
"transfer-encoding",
}

# Headers to strip from upstream before building a non-streaming JSONResponse.
# JSONResponse re-serialises the body (dropping any trailing newline Ollama appends),
# so it must compute content-length itself — passing the upstream value causes an
# off-by-one. transfer-encoding is hop-by-hop and invalid on a buffered response.
_STRIP_RESPONSE_HEADERS = {"content-length", "transfer-encoding"}


def extract_model(body: bytes) -> str | None:
"""Extract the 'model' field from a JSON request body."""
Expand Down Expand Up @@ -244,10 +250,14 @@ async def stream_gen(r=resp):
)
else:
ct = resp.headers.get("content-type", "")
passthrough_headers = {
k: v for k, v in resp.headers.items()
if k.lower() not in _STRIP_RESPONSE_HEADERS
}
return JSONResponse(
status_code=resp.status_code,
content=resp.json() if ct.startswith("application/json") else None,
headers={**dict(resp.headers), **response_headers},
headers={**passthrough_headers, **response_headers},
)

except (httpx.ConnectError, httpx.TimeoutException, httpx.HTTPStatusError) as e:
Expand Down
74 changes: 74 additions & 0 deletions tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from __future__ import annotations

import json
from unittest.mock import AsyncMock, MagicMock, patch

import httpx
import pytest

from ollama_queue_proxy.proxy import extract_model, _MODEL_MANAGEMENT_PATHS

Expand Down Expand Up @@ -39,3 +43,73 @@ def test_generate_not_in_management_paths():

def test_chat_not_in_management_paths():
    """The chat endpoint must be absent from the model-management path set."""
    management_paths = _MODEL_MANAGEMENT_PATHS
    assert "/api/chat" not in management_paths


# ---------------------------------------------------------------------------
# OQP-1: Content-Length off-by-one on non-streaming chat completions
# ---------------------------------------------------------------------------

@pytest.mark.asyncio
async def test_non_streaming_response_content_length_correct():
    """
    dispatch_request must NOT forward the upstream content-length to JSONResponse.

    Ollama appends a trailing newline to non-streaming JSON bodies, so the upstream
    content-length is 1 byte longer than the re-serialised response body.
    The returned response must carry the correct (re-serialised) content-length.

    The fixture builds the upstream body from the same compact serialisation that
    JSONResponse emits, plus a single trailing newline, so the two lengths differ
    by exactly one byte — the off-by-one this test guards against.
    """
    from fastapi import Request
    from ollama_queue_proxy.proxy import dispatch_request
    from ollama_queue_proxy.hosts import HostManager, OllamaHost
    from tests.conftest import make_config

    payload = {"message": {"role": "assistant", "content": "hi"}, "done": True}

    # JSONResponse uses compact separators — match that serialisation to get
    # the body (and therefore the content-length) the proxy should return.
    expected_body = json.dumps(
        payload,
        ensure_ascii=False,
        allow_nan=False,
        indent=None,
        separators=(",", ":"),
    ).encode("utf-8")

    # Upstream body = compact JSON + "\n", i.e. exactly 1 byte longer than the
    # re-serialised body.
    upstream_body = expected_body + b"\n"
    upstream_content_length = str(len(upstream_body))  # "60" for this payload

    # Fake upstream response: only the attributes set here are configured.
    mock_resp = MagicMock(spec=httpx.Response)
    mock_resp.status_code = 200
    mock_resp.headers = httpx.Headers({
        "content-type": "application/json",
        "content-length": upstream_content_length,
    })
    mock_resp.json.return_value = payload

    mock_client = AsyncMock(spec=httpx.AsyncClient)
    mock_client.request = AsyncMock(return_value=mock_resp)

    cfg = make_config()
    host = OllamaHost(url="http://ollama-test:11434", name="test")
    host.healthy = True
    # Bypass HostManager.__init__ and install the single healthy host directly.
    hm = HostManager.__new__(HostManager)
    hm.hosts = [host]

    # Minimal ASGI scope for a POST /api/chat request.
    scope = {
        "type": "http",
        "method": "POST",
        "path": "/api/chat",
        "query_string": b"",
        "headers": [(b"content-type", b"application/json")],
    }
    request = Request(scope)
    request.state.request_id = "test-req"

    response = await dispatch_request(
        request=request,
        body=json.dumps({"model": "llama3", "messages": []}).encode(),
        client_id=None,
        config=cfg,
        host_manager=hm,
        client=mock_client,
    )

    assert response.headers["content-length"] == str(len(expected_body)), (
        f"content-length should be {len(expected_body)} (re-serialised body), "
        f"not {upstream_content_length} (upstream body with trailing newline)"
    )
    assert response.headers["content-length"] != upstream_content_length, (
        "upstream content-length (with trailing newline) must not bleed into response"
    )