Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ classifiers = [
]
readme = "README.md"
requires-python = ">=3.8"
dependencies = ["numpy", "websockets"]
dependencies = ["numpy", "websockets>=14.0"]

[tool.setuptools_scm]
version_file = "webgpu/_version.py"
Expand Down
9 changes: 5 additions & 4 deletions webgpu/export/screenshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import sys
import os
import base64
import shutil
import subprocess
import tempfile
import threading
Expand All @@ -39,14 +40,16 @@ def main():
os.environ['DISPLAY'] = f':{disp_num}'
os.environ.pop('WAYLAND_DISPLAY', None)

tmpdir = Path(tempfile.mkdtemp(prefix="webgpu_ss_"))
try:
_run_worker()
_run_worker(tmpdir)
finally:
shutil.rmtree(tmpdir, ignore_errors=True)
xvfb_proc.terminate()
xvfb_proc.wait()


def _run_worker():
def _run_worker(tmpdir):
from playwright.sync_api import sync_playwright

ARGS = [
Expand All @@ -69,8 +72,6 @@ def _run_worker():
engine_js += "\nif (typeof window !== 'undefined') { window.RenderEngine = RenderEngine; }\n"

# Start HTTP server for serving pages to Chrome
tmpdir = Path(tempfile.mkdtemp(prefix="webgpu_ss_"))

class Quiet(SimpleHTTPRequestHandler):
def __init__(self, *a, **kw):
super().__init__(*a, directory=str(tmpdir), **kw)
Expand Down
66 changes: 0 additions & 66 deletions webgpu/link/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,71 +388,6 @@ def _get_obj(self, data):
obj = obj[data["key"]]
return obj

async def _on_message_async(self, message: str | memoryview | bytes):
data, buffers = _unpack_message(message)
obj = None
try:
msg_type = data.get("type", None)
request_id = data.get("request_id", None)

response = None

match msg_type:
case "response":
event, key = self._requests[request_id]
self._requests[request_id] = self._load_data(data.get("value", None), buffers)
if key and data.get("cache", False):
self._cache_add(key, data.get("value", None))

if isinstance(event, asyncio.Future):
event.set_result(self._requests[request_id])
else:
event.set()
return

case "call":
func = obj = self._get_obj(data)
args = self._load_data(data["args"], buffers)
response = func(*args)
try:
response = await response
except TypeError:
pass
except Exception as e:
print("error in call", type(e), str(e))

case "get":
response = obj = self._get_obj(data)

case "get_keys":
response = []

case "set":
prop = data.pop("prop", None)
key = data.pop("key", None)
obj = self._get_obj(data)
if prop is not None:
obj.__setattr__(prop, data["value"])
elif key is not None:
obj[key] = self._load_data(data["value"], buffers)

case "release_batch":
for id_ in data["ids"]:
self._objects.pop(id_, None)

case _:
print("unknown message type", msg_type)

if request_id is not None:
self._send_response(request_id, response)
except Exception as e:
import sys
import traceback

print("error in on_message", data, obj, type(e), str(e), file=sys.stderr)
if not isinstance(e, str):
traceback.print_exception(*sys.exc_info(), file=sys.stderr)

def _on_message(self, message: str):
data, buffers = _unpack_message(message)
try:
Expand Down Expand Up @@ -682,7 +617,6 @@ async def handle_callbacks():
print("error in callback", type(e), str(e))

try:
self._callback_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._callback_loop)
self._callback_task = self._callback_loop.create_task(handle_callbacks())
try:
Expand Down
33 changes: 32 additions & 1 deletion webgpu/link/link.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
/* eslint-disable */

const MAX_MESSAGE_SIZE = 100 * 1024 * 1024;

function serializeEvent(event) {
try {
event.preventDefault();
Expand Down Expand Up @@ -403,11 +405,40 @@ class CrossLink {
const prefixLen = 4 + jsonMsg.byteLength;
const size = 4 + jsonMsg.byteLength + offset;
var msg = new Uint8Array(size);
msg.set(new Uint32Array([jsonMsg.byteLength]), 0);
new DataView(msg.buffer).setUint32(0, jsonMsg.byteLength, true);
msg.set(jsonMsg, 4);

for (var bufferIndex = 0; bufferIndex < buffers.length; bufferIndex++)
msg.set(buffers[bufferIndex], prefixLen + buffer_offsets[bufferIndex]);
this._sendFrame(msg.buffer, request_id);
}
}

_sendFrame(frame, parent_request_id) {
const total = frame.byteLength;
if (total <= MAX_MESSAGE_SIZE) {
this.connection.send(frame);
return;
}
const n_chunks = Math.ceil(total / MAX_MESSAGE_SIZE);
for (let i = 0; i < n_chunks; i++) {
const offset = i * MAX_MESSAGE_SIZE;
const chunk = new Uint8Array(frame, offset, Math.min(MAX_MESSAGE_SIZE, total - offset));
const meta = {
type: 'chunk',
parent_request_id,
chunk_id: i,
n_chunks,
offset,
size: chunk.byteLength,
total_size: total,
buffer_offsets: [0, chunk.byteLength],
};
const json = new TextEncoder().encode(JSON.stringify(meta));
const msg = new Uint8Array(4 + json.byteLength + chunk.byteLength);
new DataView(msg.buffer).setUint32(0, json.byteLength, true);
msg.set(json, 4);
msg.set(chunk, 4 + json.byteLength);
this.connection.send(msg.buffer);
}
}
Expand Down
66 changes: 49 additions & 17 deletions webgpu/link/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from websockets.http11 import Response
from websockets.datastructures import Headers

from .base import LinkBaseAsync
from .base import LinkBaseAsync, _unpack_message


class WebsocketLinkBase(LinkBaseAsync):
Expand All @@ -35,7 +35,6 @@ def __init__(self):
self._event_is_connected = threading.Event()
self._event_is_running = threading.Event()
self._start_handling_messages = threading.Event()
self._send_loop = asyncio.new_event_loop()

self._websocket_thread = threading.Thread(target=self._connect, daemon=True)
self._websocket_thread.start()
Expand Down Expand Up @@ -65,6 +64,7 @@ def __init__(self):
self._port = 8700
self._auth_token = secrets.token_urlsafe(32)
self._executor = ThreadPoolExecutor(max_workers=8)
self._chunk_buffers = {}
self._stop = None
super().__init__()

Expand All @@ -80,22 +80,50 @@ def _check_auth(self, connection, request):
"""Reject WebSocket connections that don't carry a valid token."""
params = parse_qs(urlparse(request.path).query)
tokens = params.get("token", [])
if not tokens or tokens[0] != self._auth_token:
if not tokens or not secrets.compare_digest(tokens[0], self._auth_token):
return Response(403, "Forbidden", Headers())
return None

@staticmethod
def _is_response(message):
"""Quick check if a message is a response (cheap, avoids full deserialization)."""
if isinstance(message, (memoryview, bytes)):
# Binary message: JSON metadata starts at byte 4
try:
def _message_type(message):
"""Return the top-level message type, parsing only the JSON header
(not buffer payloads). Returns None on malformed input."""
try:
if isinstance(message, (memoryview, bytes)):
prefix_size = 4 + int.from_bytes(message[:4], byteorder="little")
header = message[4:prefix_size]
return b'"type":"response"' in bytes(header) or b'"type": "response"' in bytes(header)
except Exception:
return False
return '"type":"response"' in message or '"type": "response"' in message
header = json.loads(bytes(message[4:prefix_size]).decode("utf-8"))
else:
header = json.loads(message)
return header.get("type") if isinstance(header, dict) else None
except Exception:
return None

def _is_response(self, message):
return self._message_type(message) == "response"

def _is_chunk(self, message):
return isinstance(message, (memoryview, bytes)) and self._message_type(message) == "chunk"

def _reassemble_chunk(self, message):
data, buffers = _unpack_message(message)
pid = data["parent_request_id"]
buf = self._chunk_buffers.get(pid)
if buf is None:
buf = bytearray(data["total_size"])
self._chunk_buffers[pid] = buf
chunk = buffers[0]
offset = data["offset"]
buf[offset : offset + len(chunk)] = chunk
if data["chunk_id"] + 1 == data["n_chunks"]:
del self._chunk_buffers[pid]
return bytes(buf)
return None

def _dispatch(self, message):
if self._is_response(message):
self._on_message(message)
else:
self._executor.submit(self._on_message, message)

async def _websocket_handler(self, websocket, path=""):
if self._connection is not None:
Expand All @@ -107,13 +135,17 @@ async def _websocket_handler(self, websocket, path=""):
async for message in websocket:
# Handle responses inline to avoid deadlock: if all executor
# threads are blocked waiting for JS responses, queued response
# messages would never be processed.
if self._is_response(message):
self._on_message(message)
# messages would never be processed. Chunks are reassembled
# inline (single-threaded, ordered) then dispatched.
if self._is_chunk(message):
full = self._reassemble_chunk(message)
if full is not None:
self._dispatch(full)
else:
self._executor.submit(self._on_message, message)
self._dispatch(message)
finally:
self._connection = None
self._chunk_buffers.clear()

def _connect(self):
async def start_websocket():
Expand Down
5 changes: 3 additions & 2 deletions webgpu/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,8 +584,9 @@ def _on_camera_changed(self):

def _on_resize(self):
"""Called on canvas resize. Update camera uniforms (aspect ratio) and re-render."""
self._select_buffer_valid = False
self.options.update_buffers()
with self._render_mutex:
self._select_buffer_valid = False
self.options.update_buffers()
if self._js_engine is not None:
try:
self._js_engine.handleResize()
Expand Down
Loading