diff --git a/examples/basic/index.d.ts b/examples/basic/index.d.ts index bf845b5..c9bbb22 100644 --- a/examples/basic/index.d.ts +++ b/examples/basic/index.d.ts @@ -1,6 +1,56 @@ /* auto-generated by zig-addon */ /* eslint-disable */ +export interface FibProgress { + current: number + total: number +} + +export interface FileReadSummary { + path: string + bytes: number + text: string +} + +export interface ParallelReadInput { + first_path: string + second_path: string + preview_bytes: number +} + +export interface ParallelReadSummary { + first_bytes: number + second_bytes: number + total_bytes: number + preview: string +} + +export interface AsyncMathInput { + left: number + right: number + scale: number +} + +export interface AsyncMathResult { + sum: number + product: number + scaled_sum: number +} + +export interface CountProgress { + current: number + total: number +} + +export interface AbortSignal { + aborted: boolean + reason?: any + onabort: null | ((this: AbortSignal, event: any) => void) + addEventListener(name: string, listener: (this: AbortSignal, event: any) => any): void + removeEventListener(name: string, listener: (this: AbortSignal, event: any) => any): void + throwIfAborted?(): void +} + export interface FullField { name: string age: number @@ -70,7 +120,17 @@ export declare function hello(name: string): string export declare const text: string export declare function throw_error(): void export declare function fib(n: number): void -export declare function fib_async(n: number): Promise +export declare function fib_async(n: number): Promise +export declare function fib_async_progress(n: number, onEvent?: (event: FibProgress) => void): Promise +export declare function read_file_async(path: string): Promise +export declare function read_file_summary_async(path: string): Promise +export declare function parallel_read_files_async(input: ParallelReadInput): Promise +export declare function async_math_single(input: AsyncMathInput): Promise +export declare function async_void_thread(): Promise +export declare function async_fail_thread(message: string): Promise +export declare function count_async_progress_thread(total: number, onEvent?: (event: CountProgress) => void): Promise +export declare function event_mode_progress_async(total: number, onEvent?: (event: CountProgress) => void): Promise +export declare function abortable_count_async(total: number, signal: AbortSignal): Promise export declare function get_and_return_array(array: Array): Array export declare function get_named_array(array: [number, boolean, string]): [number, boolean, string] export declare function get_arraylist(array: Array): Array diff --git a/examples/basic/src/async.zig b/examples/basic/src/async.zig new file mode 100644 index 0000000..a4325da --- /dev/null +++ b/examples/basic/src/async.zig @@ -0,0 +1,199 @@ +const napi = @import("napi"); +const std = @import("std"); + +fn fibonacci(n: f64) f64 { + if (n <= 1) return n; + return fibonacci(n - 1) + fibonacci(n - 2); +} + +pub const FibProgress = struct { + current: f64, + total: f64, +}; + +pub const CountProgress = struct { + current: u32, + total: u32, +}; + +pub const AsyncMathInput = struct { + left: f64, + right: f64, + scale: f64, +}; + +pub const AsyncMathResult = struct { + sum: f64, + product: f64, + scaled_sum: f64, +}; + +pub const FileReadSummary = struct { + path: []u8, + bytes: usize, + text: []u8, +}; + +pub const ParallelReadInput = struct { + first_path: []u8, + second_path: []u8, + preview_bytes: usize, +}; + +pub const ParallelReadSummary = struct { + first_bytes: usize, + second_bytes: usize, + total_bytes: usize, + preview: []u8, +}; + +fn fibonacci_execute(data: f64) f64 { + return fibonacci(data); +} + +fn fibonacci_execute_with_progress(ctx: napi.AsyncContext(FibProgress), data: f64) !f64 { + try ctx.emit(.{ .current = 0, .total = data }); + const result = fibonacci(data); + try ctx.emit(.{ .current = data, .total = data }); + return result; +} + +fn read_file_execute(ctx: napi.AsyncContext(void), path: []u8) ![]u8 { + return try std.Io.Dir.cwd().readFileAlloc(ctx.io, path, ctx.allocator, .limited(1024 * 1024)); +} + +fn read_file_summary_execute(ctx: napi.AsyncContext(void), path: []u8) !FileReadSummary { + const text = try read_file_execute(ctx, path); + return .{ + .path = path, + .bytes = text.len, + .text = text, + }; +} + +fn async_math_execute(input: AsyncMathInput) AsyncMathResult { + const sum = input.left + input.right; + return .{ + .sum = sum, + .product = input.left * input.right, + .scaled_sum = sum * input.scale, + }; +} + +fn async_void_execute(_: void) void {} + +fn async_fail_execute(message: []u8) !void { + return napi.Error.fromReason(message); +} + +const ReadSlot = struct { + text: []u8 = &.{}, + err: ?anyerror = null, +}; + +fn read_file_slot(ctx: napi.AsyncContext(void), path: []u8, slot: *ReadSlot) std.Io.Cancelable!void { + slot.text = std.Io.Dir.cwd().readFileAlloc(ctx.io, path, ctx.allocator, .limited(1024 * 1024)) catch |err| { + slot.err = err; + return; + }; +} + +fn append_preview(output: []u8, offset: *usize, text: []const u8, limit: usize) void { + const len = @min(text.len, limit); + @memcpy(output[offset.* .. offset.* + len], text[0..len]); + offset.* += len; +} + +fn parallel_read_execute(ctx: napi.AsyncContext(void), input: ParallelReadInput) !ParallelReadSummary { + var first: ReadSlot = .{}; + var second: ReadSlot = .{}; + + try ctx.group.concurrent(ctx.io, read_file_slot, .{ ctx, input.first_path, &first }); + try ctx.group.concurrent(ctx.io, read_file_slot, .{ ctx, input.second_path, &second }); + try ctx.awaitGroup(); + + if (first.err) |err| return err; + if (second.err) |err| return err; + + const first_preview_len = @min(first.text.len, input.preview_bytes); + const second_preview_len = @min(second.text.len, input.preview_bytes); + const separator = "\n---\n"; + const preview_len = first_preview_len + separator.len + second_preview_len; + const preview = try ctx.allocator.alloc(u8, preview_len); + var offset: usize = 0; + append_preview(preview, &offset, first.text, input.preview_bytes); + @memcpy(preview[offset .. offset + separator.len], separator); + offset += separator.len; + append_preview(preview, &offset, second.text, input.preview_bytes); + + return .{ + .first_bytes = first.text.len, + .second_bytes = second.text.len, + .total_bytes = first.text.len + second.text.len, + .preview = preview, + }; +} + +fn count_with_progress_execute(ctx: napi.AsyncContext(CountProgress), total: u32) !u32 { + var current: u32 = 0; + while (current <= total) : (current += 1) { + try ctx.emit(.{ .current = current, .total = total }); + } + return total; +} + +fn abortable_count_execute(ctx: napi.AsyncContext(void), total: u32) !u32 { + var current: u32 = 0; + while (current < total) : (current += 1) { + if (current % 1024 == 0) { + try ctx.checkCancelled(); + } + } + try ctx.checkCancelled(); + return total; +} + +pub fn fib_async(n: f64) napi.Async(f64, .thread) { + return napi.Async(f64, .thread).from(n, fibonacci_execute); +} + +pub fn fib_async_progress(n: f64) napi.AsyncWithEvents(f64, FibProgress, .single) { + return napi.AsyncWithEvents(f64, FibProgress, .single).from(n, fibonacci_execute_with_progress); +} + +pub fn read_file_async(path: []u8) napi.Async([]u8, .thread) { + return napi.Async([]u8, .thread).from(path, read_file_execute); +} + +pub fn read_file_summary_async(path: []u8) napi.Async(FileReadSummary, .thread) { + return napi.Async(FileReadSummary, .thread).from(path, read_file_summary_execute); +} + +pub fn parallel_read_files_async(input: ParallelReadInput) napi.Async(ParallelReadSummary, .thread) { + return napi.Async(ParallelReadSummary, .thread).from(input, parallel_read_execute); +} + +pub fn async_math_single(input: AsyncMathInput) napi.Async(AsyncMathResult, .single) { + return napi.Async(AsyncMathResult, .single).from(input, async_math_execute); +} + +pub fn async_void_thread() napi.Async(void, .thread) { + return napi.Async(void, .thread).from({}, async_void_execute); +} + +pub fn async_fail_thread(message: []u8) napi.Async(void, .thread) { + return napi.Async(void, .thread).from(message, async_fail_execute); +} + +pub fn count_async_progress_thread(total: u32) napi.AsyncWithEvents(u32, CountProgress, .thread) { + return napi.AsyncWithEvents(u32, CountProgress, .thread).from(total, count_with_progress_execute); +} + +pub fn event_mode_progress_async(total: u32) napi.AsyncWithEvents(u32, CountProgress, .event) { + return napi.AsyncWithEvents(u32, CountProgress, .event).from(total, count_with_progress_execute); +} + +pub fn abortable_count_async(total: u32, signal: napi.AbortSignal) napi.Async(u32, .thread) { + _ = signal; + return napi.Async(u32, .thread).from(total, abortable_count_execute); +} diff --git a/examples/basic/src/hello.zig b/examples/basic/src/hello.zig index 5eab9cc..7cf6047 100644 --- a/examples/basic/src/hello.zig +++ b/examples/basic/src/hello.zig @@ -4,6 +4,7 @@ const number = @import("number.zig"); const string = @import("string.zig"); const err = @import("err.zig"); const worker = @import("worker.zig"); +const async_examples = @import("async.zig"); const array = @import("array.zig"); const object = @import("object.zig"); const function = @import("function.zig"); @@ -28,7 +29,17 @@ pub const text = string.text; pub const throw_error = err.throw_error; pub const fib = worker.fib; -pub const fib_async = worker.fib_async; +pub const fib_async = async_examples.fib_async; +pub const fib_async_progress = async_examples.fib_async_progress; +pub const read_file_async = async_examples.read_file_async; +pub const read_file_summary_async = async_examples.read_file_summary_async; +pub const parallel_read_files_async = async_examples.parallel_read_files_async; +pub const async_math_single = async_examples.async_math_single; +pub const async_void_thread = async_examples.async_void_thread; +pub const async_fail_thread = async_examples.async_fail_thread; +pub const count_async_progress_thread = async_examples.count_async_progress_thread; +pub const event_mode_progress_async = async_examples.event_mode_progress_async; +pub const abortable_count_async = async_examples.abortable_count_async; pub const get_and_return_array = array.get_and_return_array; pub const get_named_array = array.get_named_array; diff --git a/examples/basic/src/thread_safe_function.zig b/examples/basic/src/thread_safe_function.zig index 1e09dea..59f4075 100644 --- a/examples/basic/src/thread_safe_function.zig +++ b/examples/basic/src/thread_safe_function.zig @@ -13,16 +13,25 @@ fn sleepForFiveSeconds() void { } fn execute_thread_safe_function(tsfn: *napi.ThreadSafeFunction(Args, Return, true, 0)) void { + defer tsfn.release(.Release) catch {}; sleepForFiveSeconds(); - tsfn.Ok(.{ 1, 2 }, .NonBlocking); + tsfn.Ok(.{ 1, 2 }, .NonBlocking) catch {}; } fn execute_thread_safe_function_with_error(tsfn: *napi.ThreadSafeFunction(Args, Return, true, 0)) void { + defer tsfn.release(.Release) catch {}; sleepForFiveSeconds(); - tsfn.Err(napi.Error.withReason("TSFN Error"), .NonBlocking); + tsfn.Err(napi.Error.withReason("TSFN Error"), .NonBlocking) catch {}; } -pub fn call_thread_safe_function(tsfn: *napi.ThreadSafeFunction(Args, Return, true, 0)) void { - _ = std.Thread.spawn(.{}, execute_thread_safe_function, .{tsfn}) catch @panic("Failed to spawn thread"); - _ = std.Thread.spawn(.{}, execute_thread_safe_function_with_error, .{tsfn}) catch @panic("Failed to spawn thread"); +pub fn call_thread_safe_function(tsfn: *napi.ThreadSafeFunction(Args, Return, true, 0)) !void { + try tsfn.acquire(); + const worker = try std.Thread.spawn(.{}, execute_thread_safe_function, .{tsfn}); + worker.detach(); + + try tsfn.acquire(); + const worker_with_error = try std.Thread.spawn(.{}, execute_thread_safe_function_with_error, .{tsfn}); + worker_with_error.detach(); + + try tsfn.release(.Release); } diff --git a/examples/basic/src/worker.zig b/examples/basic/src/worker.zig index 20d74fc..2d52059 100644 --- a/examples/basic/src/worker.zig +++ b/examples/basic/src/worker.zig @@ -6,12 +6,8 @@ fn fibonacci(n: f64) f64 { return fibonacci(n - 1) + fibonacci(n - 2); } -fn fibonacci_execute(_: napi.Env, data: f64) void { - const result = fibonacci(data); - - const allocator = std.heap.page_allocator; - const message = std.fmt.allocPrint(allocator, "Fibonacci result: {d}", .{result}) catch @panic("OOM"); - defer allocator.free(message); +fn fibonacci_execute(data: f64) f64 { + return fibonacci(data); } fn fibonacci_on_complete(_: napi.Env, data: f64) void { @@ -24,9 +20,3 @@ pub fn fib(env: napi.Env, n: f64) void { const worker = napi.Worker(env, .{ .data = n, .Execute = fibonacci_execute, .OnComplete = fibonacci_on_complete }); worker.Queue(); } - -pub fn fib_async(env: napi.Env, n: f64) napi.Promise { - const worker = napi.Worker(env, .{ .data = n, .Execute = fibonacci_execute, .OnComplete = fibonacci_on_complete }); - const promise = worker.AsyncQueue(); - return promise; -} diff --git a/examples/init/index.d.ts b/examples/init/index.d.ts index 9926dcf..0ed5ba8 100644 --- a/examples/init/index.d.ts +++ b/examples/init/index.d.ts @@ -7,7 +7,7 @@ export declare function add(left: number, right: number): number export declare function hello(name: string): string export declare const text: string export declare function fib(n: number): void -export declare function fib_async(n: number): Promise +export declare function fib_async(n: number): Promise export declare function get_and_return_array(array: Array): Array export declare function get_named_array(array: [number, boolean, string]): [number, boolean, string] export declare function get_arraylist(array: Array): Array diff --git a/examples/init/src/hello.zig b/examples/init/src/hello.zig index 77f2fd1..e1cd4e8 100644 --- a/examples/init/src/hello.zig +++ b/examples/init/src/hello.zig @@ -12,12 +12,8 @@ fn fibonacci(n: f64) f64 { return fibonacci(n - 1) + fibonacci(n - 2); } -fn fibonacci_execute(_: napi.Env, data: f64) void { - const result = fibonacci(data); - - const allocator = std.heap.page_allocator; - const message = std.fmt.allocPrint(allocator, "Fibonacci result: {d}", .{result}) catch @panic("OOM"); - defer allocator.free(message); +fn fibonacci_execute(data: f64) f64 { + return fibonacci(data); } fn fibonacci_on_complete(_: napi.Env, data: f64) void { @@ -45,10 +41,8 @@ fn fib(env: napi.Env, n: f64) void { worker.Queue(); } -fn fib_async(env: napi.Env, n: f64) napi.Promise { - const worker = napi.Worker(env, .{ .data = n, .Execute = fibonacci_execute, .OnComplete = fibonacci_on_complete }); - const promise = worker.AsyncQueue(); - return promise; +fn fib_async(n: f64) napi.Async(f64, .thread) { + return napi.Async(f64, .thread).from(n, fibonacci_execute); } fn get_and_return_array(array: []f32) []f32 { diff --git a/src/build/napi-tsgen.zig b/src/build/napi-tsgen.zig index e79d4b7..2d3087a 100644 --- a/src/build/napi-tsgen.zig +++ b/src/build/napi-tsgen.zig @@ -73,6 +73,34 @@ fn isPromiseType(comptime T: type) bool { return T == napi.Promise; } +fn isAsyncDescriptorType(comptime T: type) bool { + switch (@typeInfo(T)) { + .@"struct", .@"enum", .@"union", .@"opaque" => {}, + else => return false, + } + return @hasDecl(T, "is_napi_async_descriptor"); +} + +fn isAbortSignalType(comptime T: type) bool { + switch (@typeInfo(T)) { + .@"struct", .@"enum", .@"union", .@"opaque" => {}, + else => return false, + } + return @hasDecl(T, "is_napi_abort_signal"); +} + +fn asyncResultType(comptime T: type) type { + return T.async_result_type; +} + +fn asyncEventType(comptime T: type) type { + return T.async_event_type; +} + +fn asyncHasEvents(comptime T: type) bool { + return @hasDecl(T, "async_has_events") and T.async_has_events; +} + fn typedArrayName(comptime T: type) ?[]const u8 { switch (@typeInfo(T)) { .@"struct", .@"enum", .@"union", .@"opaque" => {}, @@ -211,6 +239,23 @@ const State = struct { } }; +fn emitAbortSignalDecl(state: *State) !void { + if (state.emitted.contains("AbortSignal")) return; + try state.emitted.put("AbortSignal", {}); + try append(&state.declarations, + \\export interface AbortSignal { + \\ aborted: boolean + \\ reason?: any + \\ onabort: null | ((this: AbortSignal, event: any) => void) + \\ addEventListener(name: string, listener: (this: AbortSignal, event: any) => any): void + \\ removeEventListener(name: string, listener: (this: AbortSignal, event: any) => any): void + \\ throwIfAborted?(): void + \\} + \\ + \\ + ); +} + const FunctionSource = struct { file_path: []const u8, fn_name: []const u8, @@ -939,7 +984,15 @@ fn emitType(state: *State, comptime T: type) ![]const u8 { if (isNumeric(T)) return "number"; if (isStringLike(T)) return "string"; + if (comptime isAbortSignalType(T)) { + try emitAbortSignalDecl(state); + return "AbortSignal"; + } if (isPromiseType(T)) return "Promise"; + if (comptime isAsyncDescriptorType(T)) { + const result_type = try emitType(state, asyncResultType(T)); + return try std.fmt.allocPrint(state.allocator, "Promise<{s}>", .{result_type}); + } if (T == napi.Buffer) return "Buffer"; if (T == napi.ArrayBuffer) return "ArrayBuffer"; @@ -1145,6 +1198,10 @@ fn emitFunctionLikeWithNames(state: *State, comptime T: type, comptime is_tsfn: return try buf.toOwnedSlice(); } +fn emitAsyncEventCallbackType(state: *State, comptime T: type) ![]const u8 { + return try std.fmt.allocPrint(state.allocator, "(event: {s}) => void", .{try emitType(state, asyncEventType(T))}); +} + fn emitMethodSignature(state: *State, writer: *StringBuilder, comptime fn_type: type, comptime name: []const u8, comptime skip_first: bool, param_names: ?[]const []const u8) !void { const info = @typeInfo(fn_type).@"fn"; try appendFmt(writer, "{s}(", .{name}); @@ -1161,6 +1218,11 @@ fn emitMethodSignature(state: *State, writer: *StringBuilder, comptime fn_type: fn emitMethodParams(state: *State, writer: *StringBuilder, comptime fn_type: type, comptime skip_first: bool, param_names: ?[]const []const u8) !void { const info = @typeInfo(fn_type).@"fn"; + const ret = info.return_type.?; + const ret_payload = switch (@typeInfo(ret)) { + .error_union => |eu| eu.payload, + else => ret, + }; var first = true; const total = if (skip_first) info.params.len - 1 else info.params.len; const source_offset: usize = if (param_names) |names| @@ -1176,6 +1238,10 @@ fn emitMethodParams(state: *State, writer: *StringBuilder, comptime fn_type: typ const effective_names = if (param_names) |names| names[source_offset..] else null; try appendFmt(writer, "{s}: {s}", .{ resolvedArgName(effective_names, arg_idx, total), ts_type }); } + if (comptime isAsyncDescriptorType(ret_payload) and asyncHasEvents(ret_payload)) { + if (!first) try append(writer, ", "); + try appendFmt(writer, "onEvent?: {s}", .{try emitAsyncEventCallbackType(state, ret_payload)}); + } try append(writer, ")"); } @@ -1279,6 +1345,10 @@ fn emitExportFunction(state: *State, comptime name: []const u8, comptime fn_valu .error_union => |eu| eu.payload, else => ret, }; + if (comptime isAsyncDescriptorType(ret_payload) and asyncHasEvents(ret_payload)) { + if (!first) try append(&state.exports, ", "); + try appendFmt(&state.exports, "onEvent?: {s}", .{try emitAsyncEventCallbackType(state, ret_payload)}); + } const ret_payload_info = @typeInfo(ret_payload); if (comptime isFunctionType(ret_payload)) { const returned_param_names = try state.source.getReturnedFunctionParamNames(name); @@ -1507,6 +1577,43 @@ fn emitSourceStructType(state: *State, file_path: []const u8, struct_expr: []con return try buf.toOwnedSlice(); } +const ParsedAsyncSourceType = struct { + result_expr: []const u8, + event_expr: ?[]const u8, +}; + +fn parseAsyncSourceTypeExpr(allocator: std.mem.Allocator, type_expr: []const u8) !?ParsedAsyncSourceType { + const trimmed = std.mem.trim(u8, type_expr, " \t\r\n"); + const async_markers = [_][]const u8{ "napi.Async(", "Async(" }; + const async_events_markers = [_][]const u8{ "napi.AsyncWithEvents(", "AsyncWithEvents(" }; + + inline for (async_events_markers) |marker| { + if (std.mem.startsWith(u8, trimmed, marker) and std.mem.endsWith(u8, trimmed, ")")) { + const args = try parseCallArguments(allocator, trimmed, marker.len - 1); + if (args.items.len >= 2) { + return .{ + .result_expr = args.items[0], + .event_expr = args.items[1], + }; + } + } + } + + inline for (async_markers) |marker| { + if (std.mem.startsWith(u8, trimmed, marker) and std.mem.endsWith(u8, trimmed, ")")) { + const args = try parseCallArguments(allocator, trimmed, marker.len - 1); + if (args.items.len >= 1) { + return .{ + .result_expr = args.items[0], + .event_expr = null, + }; + } + } + } + + return null; +} + fn emitSourceTypeExpr(state: *State, file_path: []const u8, type_expr: []const u8, depth: usize) anyerror![]const u8 { if (depth > 8) return "unknown"; @@ -1526,6 +1633,17 @@ fn emitSourceTypeExpr(state: *State, file_path: []const u8, type_expr: []const u if (std.mem.eql(u8, trimmed, "bool")) return "boolean"; if (isSourceNumericType(trimmed)) return "number"; if (isSourceStringType(trimmed)) return "string"; + if (std.mem.eql(u8, trimmed, "AbortSignal") or + std.mem.eql(u8, trimmed, "napi.AbortSignal") or + std.mem.endsWith(u8, trimmed, ".AbortSignal")) + { + try emitAbortSignalDecl(state); + return "AbortSignal"; + } + if (try parseAsyncSourceTypeExpr(state.allocator, trimmed)) |async_source| { + const result_ts = try emitSourceTypeExpr(state, file_path, async_source.result_expr, depth + 1); + return try std.fmt.allocPrint(state.allocator, "Promise<{s}>", .{result_ts}); + } if (std.mem.eql(u8, trimmed, "napi.Promise")) return "Promise"; if (std.mem.eql(u8, trimmed, "napi.Buffer")) return "Buffer"; @@ -1649,6 +1767,13 @@ fn buildInitFunctionDeclaration(state: *State, export_name: []const u8, file_pat }); } + if (try parseAsyncSourceTypeExpr(state.allocator, signature.return_type_expr)) |async_source| { + if (async_source.event_expr) |event_expr| { + if (!first) try append(&buf, ", "); + try appendFmt(&buf, "onEvent?: (event: {s}) => void", .{try emitSourceTypeExpr(state, file_path, event_expr, 0)}); + } + } + try appendFmt(&buf, "): {s}\n", .{try emitSourceTypeExpr(state, file_path, signature.return_type_expr, 0)}); return try buf.toOwnedSlice(); } diff --git a/src/napi.zig b/src/napi.zig index 8684bb1..52dc1c1 100644 --- a/src/napi.zig +++ b/src/napi.zig @@ -6,6 +6,8 @@ const module = @import("./prelude/module.zig"); const worker = @import("./napi/wrapper/worker.zig"); const err = @import("./napi/wrapper/error.zig"); const thread_safe_function = @import("./napi/wrapper/thread_safe_function.zig"); +const async = @import("./napi/async.zig"); +const abort_signal = @import("./napi/abort_signal.zig"); const class = @import("./napi/wrapper/class.zig"); const buffer = @import("./napi/wrapper/buffer.zig"); const arraybuffer = @import("./napi/wrapper/arraybuffer.zig"); @@ -35,6 +37,12 @@ pub const Function = function.Function; pub const CallbackInfo = callback_info.CallbackInfo; pub const Worker = worker.Worker; pub const ThreadSafeFunction = thread_safe_function.ThreadSafeFunction; +pub const ThreadSafeFunctionMode = thread_safe_function.ThreadSafeFunctionMode; +pub const ThreadSafeFunctionReleaseMode = thread_safe_function.ThreadSafeFunctionReleaseMode; +pub const AsyncRuntime = async.RuntimeModel; +pub const CancelToken = async.CancelToken; +pub const AbortSignal = abort_signal.AbortSignal; +pub const resolveRequestedRuntime = async.resolveRequestedRuntime; pub const Class = class.Class; pub const ClassWithoutInit = class.ClassWithoutInit; pub const Buffer = buffer.Buffer; @@ -57,6 +65,15 @@ pub fn FunctionRef(comptime Args: type, comptime Return: type) type { return reference.Reference(function.Function(Args, Return)); } pub const ObjectRef = reference.Reference(value.Object); +pub fn AsyncContext(comptime Event: type) type { + return async.AsyncContext(Event); +} +pub fn Async(comptime Result: type, comptime runtime: async.RuntimeModel) type { + return async.Async(Result, runtime); +} +pub fn AsyncWithEvents(comptime Result: type, comptime Event: type, comptime runtime: async.RuntimeModel) type { + return async.AsyncWithEvents(Result, Event, runtime); +} pub const NODE_API_MODULE = module.NODE_API_MODULE; pub const NODE_API_MODULE_WITH_INIT = module.NODE_API_MODULE_WITH_INIT; diff --git a/src/napi/abort_signal.zig b/src/napi/abort_signal.zig new file mode 100644 index 0000000..57331cf --- /dev/null +++ b/src/napi/abort_signal.zig @@ -0,0 +1,207 @@ +const std = @import("std"); +const napi = @import("napi-sys").napi_sys; +const Env = @import("./env.zig").Env; +const String = @import("./value/string.zig").String; +const Undefined = @import("./value/undefined.zig").Undefined; +const NapiError = @import("./wrapper/error.zig"); +const GlobalAllocator = @import("./util/allocator.zig"); + +pub const AbortCallback = *const fn (?*anyopaque) void; + +const AbortRegistrationStack = struct { + allocator: std.mem.Allocator, + registrations: std.array_list.Managed(*AbortRegistration), + + fn init(allocator: std.mem.Allocator) AbortRegistrationStack { + return .{ + .allocator = allocator, + .registrations = std.array_list.Managed(*AbortRegistration).init(allocator), + }; + } + + fn deinit(self: *AbortRegistrationStack) void { + self.registrations.deinit(); + } + + fn append(self: *AbortRegistrationStack, registration: *AbortRegistration) !void { + try self.registrations.append(registration); + } + + fn remove(self: *AbortRegistrationStack, registration: *AbortRegistration) void { + for (self.registrations.items, 0..) |item, index| { + if (item == registration) { + _ = self.registrations.swapRemove(index); + return; + } + } + } +}; + +pub const AbortRegistration = struct { + env: napi.napi_env, + signal_ref: napi.napi_ref, + stack: *AbortRegistrationStack, + callback_context: ?*anyopaque, + callback: AbortCallback, + active: bool = true, + + pub fn requestAbort(self: *AbortRegistration) void { + if (!self.active) return; + self.callback(self.callback_context); + } + + pub fn release(self: *AbortRegistration) void { + if (self.active) { + self.stack.remove(self); + self.active = false; + } + if (self.signal_ref != null) { + _ = napi.napi_delete_reference(self.env, self.signal_ref); + self.signal_ref = null; + } + GlobalAllocator.globalAllocator().destroy(self); + } +}; + +pub const AbortSignal = struct { + pub const is_napi_abort_signal = true; + + env: napi.napi_env, + raw: napi.napi_value, + + pub fn from_raw(env: napi.napi_env, raw: napi.napi_value) AbortSignal { + return .{ .env = env, .raw = raw }; + } + + pub fn from_napi_value(env: napi.napi_env, raw: napi.napi_value) AbortSignal { + return from_raw(env, raw); + } + + pub fn isAborted(self: AbortSignal) !bool { + var aborted_value: napi.napi_value = undefined; + const get_status = napi.napi_get_named_property(self.env, self.raw, "aborted", &aborted_value); + if (get_status != napi.napi_ok) { + return NapiError.Error.fromStatus(NapiError.Status.New(get_status)); + } + + var aborted = false; + const bool_status = napi.napi_get_value_bool(self.env, aborted_value, &aborted); + if (bool_status != napi.napi_ok) { + return NapiError.Error.fromStatus(NapiError.Status.New(bool_status)); + } + return aborted; + } + + pub fn bind(self: AbortSignal, callback_context: ?*anyopaque, callback: AbortCallback) !*AbortRegistration { + const stack = try ensureStack(self.env, self.raw); + + const allocator = GlobalAllocator.globalAllocator(); + const registration = try allocator.create(AbortRegistration); + errdefer allocator.destroy(registration); + + var signal_ref: napi.napi_ref = null; + const ref_status = napi.napi_create_reference(self.env, self.raw, 1, &signal_ref); + if (ref_status != napi.napi_ok) { + return NapiError.Error.fromStatus(NapiError.Status.New(ref_status)); + } + errdefer _ = napi.napi_delete_reference(self.env, signal_ref); + + registration.* = .{ + .env = self.env, + .signal_ref = signal_ref, + .stack = stack, + .callback_context = callback_context, + .callback = callback, + }; + try stack.append(registration); + return registration; + } +}; + +fn ensureStack(env: napi.napi_env, signal: napi.napi_value) !*AbortRegistrationStack { + const allocator = GlobalAllocator.globalAllocator(); + + var stack_ptr: ?*anyopaque = null; + const remove_status = napi.napi_remove_wrap(env, signal, &stack_ptr); + const stack: *AbortRegistrationStack = if (remove_status == napi.napi_ok and stack_ptr != null) + @ptrCast(@alignCast(stack_ptr.?)) + else blk: { + const new_stack = try allocator.create(AbortRegistrationStack); + new_stack.* = AbortRegistrationStack.init(allocator); + break :blk new_stack; + }; + + errdefer if (!(remove_status == napi.napi_ok and stack_ptr != null)) { + stack.deinit(); + allocator.destroy(stack); + }; + + var ref: napi.napi_ref = null; + const wrap_status = napi.napi_wrap(env, signal, @ptrCast(stack), finalizeStack, null, &ref); + if (wrap_status != napi.napi_ok) { + return NapiError.Error.fromStatus(NapiError.Status.New(wrap_status)); + } + if (ref != null) { + var ref_count: u32 = 0; + _ = napi.napi_reference_unref(env, ref, &ref_count); + } + + try installOnAbort(env, signal); + return stack; +} + +fn installOnAbort(env: napi.napi_env, signal: napi.napi_value) !void { + var callback: napi.napi_value = undefined; + const create_status = napi.napi_create_function( + env, + "onabort", + "onabort".len, + onAbort, + null, + &callback, + ); + if (create_status != napi.napi_ok) { + return NapiError.Error.fromStatus(NapiError.Status.New(create_status)); + } + + const set_status = napi.napi_set_named_property(env, signal, "onabort", callback); + if (set_status != napi.napi_ok) { + return NapiError.Error.fromStatus(NapiError.Status.New(set_status)); + } +} + +fn onAbort(env: napi.napi_env, info: napi.napi_callback_info) callconv(.c) napi.napi_value { + var this: napi.napi_value = null; + var argc: usize = 0; + const cb_status = napi.napi_get_cb_info(env, info, &argc, null, &this, null); + if (cb_status == napi.napi_ok and this != null) { + var stack_ptr: ?*anyopaque = null; + const unwrap_status = napi.napi_unwrap(env, this, &stack_ptr); + if (unwrap_status == napi.napi_ok and stack_ptr != null) { + const stack: *AbortRegistrationStack = @ptrCast(@alignCast(stack_ptr.?)); + for (stack.registrations.items) |registration| { + registration.requestAbort(); + } + } + } + + return Undefined.New(Env.from_raw(env)).raw; +} + +fn finalizeStack(_: napi.napi_env, finalize_data: ?*anyopaque, _: ?*anyopaque) callconv(.c) void { + const data = finalize_data orelse return; + const stack: *AbortRegistrationStack = @ptrCast(@alignCast(data)); + stack.deinit(); + GlobalAllocator.globalAllocator().destroy(stack); +} + +pub fn abortErrorValue(env: Env) napi.napi_value { + var error_value: napi.napi_value = undefined; + const code = String.New(env, "AbortError").raw; + const message = String.New(env, "AbortError").raw; + const create_status = napi.napi_create_error(env.raw, code, message, &error_value); + std.debug.assert(create_status == napi.napi_ok); + const set_status = napi.napi_set_named_property(env.raw, error_value, "name", code); + std.debug.assert(set_status == napi.napi_ok); + return error_value; +} diff --git a/src/napi/async.zig b/src/napi/async.zig new file mode 100644 index 0000000..383d3cf --- /dev/null +++ b/src/napi/async.zig @@ -0,0 +1,680 @@ +const std = @import("std"); +const napi = @import("napi-sys").napi_sys; +const Env = @import("./env.zig").Env; +const Promise = @import("./value/promise.zig").Promise; +const String = @import("./value/string.zig").String; +const Undefined = @import("./value/undefined.zig").Undefined; +const Napi = @import("./util/napi.zig").Napi; +const NapiError = @import("./wrapper/error.zig"); +const GlobalAllocator = @import("./util/allocator.zig"); +const AbortSignal = @import("./abort_signal.zig").AbortSignal; +const AbortRegistration = @import("./abort_signal.zig").AbortRegistration; + +var threaded_runtime_mutex: std.atomic.Mutex = .unlocked; +var threaded_runtime_initialized = false; +var threaded_runtime: std.Io.Threaded = undefined; + +pub const RuntimeModel = enum { + single, + thread, + event, + + // Backward-compatible spellings kept while examples and downstream users migrate. + serial, + threaded, + evented, +}; + +const EffectiveRuntime = enum { + single, + thread, +}; + +pub const CancelToken = struct { + cancelled: std.atomic.Value(bool) = std.atomic.Value(bool).init(false), + + pub fn cancel(self: *CancelToken) void { + self.cancelled.store(true, .seq_cst); + } + + pub fn isCancelled(self: *const CancelToken) bool { + return self.cancelled.load(.seq_cst); + } + + pub fn check(self: *const CancelToken) !void { + if (self.isCancelled()) return error.Cancelled; + } +}; + +pub fn resolveRequestedRuntime(runtime: RuntimeModel) RuntimeModel { + return switch (runtime) { + .serial => .single, + .threaded => .thread, + .evented => .event, + else => runtime, + }; +} + +fn effectiveRuntime(runtime: RuntimeModel) EffectiveRuntime { + return switch (resolveRequestedRuntime(runtime)) { + .single => .single, + .thread => .thread, + .event => if (std.Io.Evented == void) .thread else .single, + .serial, .threaded, .evented => unreachable, + }; +} + +fn singleIo() std.Io { + return std.Io.Threaded.global_single_threaded.io(); +} + +fn threadedIo() std.Io { + while (!threaded_runtime_mutex.tryLock()) { + std.Thread.yield() catch {}; + } + defer threaded_runtime_mutex.unlock(); + + if (!threaded_runtime_initialized) { + threaded_runtime = std.Io.Threaded.init(GlobalAllocator.globalAllocator(), .{}); + threaded_runtime_initialized = true; + } + + return threaded_runtime.io(); +} + +fn ioForRuntime(effective_runtime: EffectiveRuntime) std.Io { + return switch (effective_runtime) { + .single => singleIo(), + .thread => threadedIo(), + }; +} + +pub fn AsyncContext(comptime Event: type) type { + return struct { + allocator: std.mem.Allocator, + io: std.Io, + group: *std.Io.Group, + runtime: RuntimeModel, + effective_runtime: RuntimeModel, + cancel_token: *const CancelToken, + emitter_ptr: ?*anyopaque, + emit_fn: ?*const fn (?*anyopaque, Event) anyerror!void, + + const Self = @This(); + + pub fn emit(self: Self, event: Event) !void { + if (Event == void) { + @compileError("AsyncContext(void) does not support emit()"); + } + try self.cancel_token.check(); + const emit_fn = self.emit_fn orelse return error.InvalidArg; + const emitter_ptr = self.emitter_ptr orelse return error.InvalidArg; + try emit_fn(emitter_ptr, event); + } + + pub fn isCancelled(self: Self) bool { + return self.cancel_token.isCancelled(); + } + + pub fn checkCancelled(self: Self) !void { + try self.cancel_token.check(); + } + + pub fn awaitGroup(self: Self) !void { + try self.group.await(self.io); + } + + pub fn cancelGroup(self: Self) void { + self.group.cancel(self.io); + } + }; +} + +pub fn mapAnyError(err: anyerror) NapiError.Error { + if (NapiError.last_error) |last_error| return last_error; + + return switch (err) { + error.Canceled, error.Cancelled => NapiError.Error.withReason(@as([]const u8, "AbortError")), + error.Closing => NapiError.Error.withStatus(@as([]const u8, "Closing")), + else => |actual_err| blk: { + const name = @errorName(actual_err); + break :blk NapiError.Error.withStatus(name[0..name.len]); + }, + }; +} + +fn createOptionalCallbackRef(env: napi.napi_env, raw: ?napi.napi_value) !?napi.napi_ref { + const value = raw orelse return null; + + var value_type: napi.napi_valuetype = undefined; + const typeof_status = napi.napi_typeof(env, value, &value_type); + if (typeof_status != napi.napi_ok) { + return NapiError.Error.fromStatus(NapiError.Status.New(typeof_status)); + } + + switch (value_type) { + napi.napi_undefined, napi.napi_null => return null, + napi.napi_function => {}, + else => return error.InvalidArg, + } + + var ref: napi.napi_ref = null; + const ref_status = napi.napi_create_reference(env, value, 1, &ref); + if (ref_status != napi.napi_ok) { + return NapiError.Error.fromStatus(NapiError.Status.New(ref_status)); + } + return ref; +} + +fn releaseCallbackRef(env: napi.napi_env, ref: *?napi.napi_ref) void { + if (ref.*) |actual_ref| { + _ = napi.napi_delete_reference(env, actual_ref); + ref.* = null; + } +} + +fn validateTaskRunSignature(comptime Input: type, comptime Result: type, comptime Event: type, comptime RunFn: anytype) void { + const run_type = @TypeOf(RunFn); + const info = @typeInfo(run_type); + if (info != .@"fn") { + @compileError("Async task runner must be a function"); + } + + const params = info.@"fn".params; + if (params.len != 1 and params.len != 2) { + @compileError("Async task runner must accept (input) or (AsyncContext(Event), input)"); + } + + if (params.len == 1) { + if (params[0].type.? != Input) { + @compileError("Async task runner input type mismatch"); + } + } else { + if (params[0].type.? != AsyncContext(Event)) { + @compileError("Async task runner context type must be napi.AsyncContext(Event)"); + } + if (params[1].type.? != Input) { + @compileError("Async task runner input type mismatch"); + } + } + + const return_type = info.@"fn".return_type.?; + switch (@typeInfo(return_type)) { + .error_union => |eu| { + if (eu.payload != Result) { + @compileError("Async task runner return type mismatch"); + } + }, + else => { + if (return_type != Result) { + @compileError("Async task runner return type mismatch"); + } + }, + } +} + +pub fn Async(comptime Result: type, comptime runtime: RuntimeModel) type { + return AsyncTaskDescriptor(Result, void, runtime); +} + +pub fn AsyncWithEvents(comptime Result: type, comptime Event: type, comptime runtime: RuntimeModel) type { + return AsyncTaskDescriptor(Result, Event, runtime); +} + +fn AsyncTaskDescriptor(comptime Result: type, comptime Event: type, comptime runtime: RuntimeModel) type { + return struct { + pub const is_napi_async_descriptor = true; + pub const async_result_type = Result; + pub const async_event_type = Event; + pub const async_runtime_model = runtime; + pub const async_has_events = Event != void; + + base: *AsyncTaskDescriptorBase, + + const Self = @This(); + + pub fn from(input: anytype, comptime run_fn: anytype) Self { + const Input = @TypeOf(input); + validateTaskRunSignature(Input, Result, Event, run_fn); + + const allocator = GlobalAllocator.globalAllocator(); + const Impl = AsyncTaskDescriptorImpl(Input, Result, Event, runtime, run_fn); + var impl = allocator.create(Impl) catch @panic("OOM"); + impl.* = .{ + .base = .{ + .allocator = allocator, + .schedule_fn = Impl.schedule, + .destroy_fn = Impl.destroy, + }, + .input = input, + }; + return .{ .base = &impl.base }; + } + + pub fn schedule(self: *Self, env: Env) !Promise { + return try self.scheduleWithListenerAndSignal(env, null, null); + } + + pub fn scheduleWithListener(self: *Self, env: Env, listener: ?napi.napi_value) !Promise { + return try self.scheduleWithListenerAndSignal(env, listener, null); + } + + pub fn scheduleWithSignal(self: *Self, env: Env, signal: ?AbortSignal) !Promise { + return try self.scheduleWithListenerAndSignal(env, null, signal); + } + + pub fn scheduleWithListenerAndSignal(self: *Self, env: Env, listener: ?napi.napi_value, signal: ?AbortSignal) !Promise { + const base = self.base; + return try base.schedule_fn(base, env.raw, listener, signal); + } + + pub fn deinit(self: *Self) void { + self.base.destroy_fn(self.base); + } + }; +} + +const AsyncTaskDescriptorBase = struct { + allocator: std.mem.Allocator, + schedule_fn: *const fn (*AsyncTaskDescriptorBase, napi.napi_env, ?napi.napi_value, ?AbortSignal) anyerror!Promise, + destroy_fn: *const fn (*AsyncTaskDescriptorBase) void, +}; + +fn AsyncTaskDescriptorImpl( + comptime Input: type, + comptime Result: type, + comptime Event: type, + comptime runtime: RuntimeModel, + comptime run_fn: anytype, +) type { + return struct { + base: AsyncTaskDescriptorBase, + input: Input, + + const Self = @This(); + + fn schedule(base: *AsyncTaskDescriptorBase, env_raw: napi.napi_env, listener: ?napi.napi_value, signal: ?AbortSignal) !Promise { + const self: *Self = @alignCast(@fieldParentPtr("base", base)); + const operation = try AsyncTaskOperation(Input, Result, Event, runtime, run_fn).create(Env.from_raw(env_raw), self.input, listener, signal); + defer base.destroy_fn(base); + return try operation.submit(); + } + + fn destroy(base: *AsyncTaskDescriptorBase) void { + const self: *Self = @alignCast(@fieldParentPtr("base", base)); + self.base.allocator.destroy(self); + } + }; +} + +fn AsyncTaskOperation( + comptime Input: type, + comptime Result: type, + comptime Event: type, + comptime runtime: RuntimeModel, + comptime run_fn: anytype, +) type { + return struct { + allocator: std.mem.Allocator, + env: napi.napi_env, + promise: Promise, + input: Input, + result: Result = if (Result == void) {} else undefined, + err: ?NapiError.Error = null, + listener_ref: ?napi.napi_ref = null, + abort_registration: ?*AbortRegistration = null, + cancel_token: CancelToken = .{}, + controller_thread: ?std.Thread = null, + future: ?std.Io.Future(void) = null, + tsfn_raw: napi.napi_threadsafe_function = null, + state_mutex: std.Io.Mutex = .init, + state_cond: std.Io.Condition = .init, + task_done: bool = false, + cancel_requested: bool = false, + cancel_dispatched: bool = false, + closed: bool = false, + + const Self = @This(); + const Context = AsyncContext(Event); + const run_info = @typeInfo(@TypeOf(run_fn)).@"fn"; + const DispatchKind = enum { event, completion }; + const DispatchData = struct { + kind: DispatchKind, + payload: ?*Event = null, + }; + + fn create(env: Env, input: Input, listener: ?napi.napi_value, signal: ?AbortSignal) !*Self { + const allocator = GlobalAllocator.globalAllocator(); + const self = try allocator.create(Self); + errdefer allocator.destroy(self); + + self.* = .{ + .allocator = allocator, + .env = env.raw, + .promise = Promise.New(env), + .input = input, + .listener_ref = if (Event == void) null else try createOptionalCallbackRef(env.raw, listener), + }; + errdefer releaseCallbackRef(env.raw, &self.listener_ref); + + if (signal) |abort_signal| { + self.abort_registration = try abort_signal.bind(@ptrCast(self), requestAbortFromSignal); + } + + return self; + } + + fn submit(self: *Self) !Promise { + errdefer self.destroy(self.env); + + const promise = self.promise; + if (self.abort_registration != null and self.isAbortRequestedFromSignal()) { + self.cancel_token.cancel(); + self.cancel_requested = true; + self.promise.RejectAbortError() catch {}; + self.destroy(self.env); + return promise; + } + + switch (effectiveRuntime(runtime)) { + .single => self.runSingle(), + .thread => { + try self.initThreadDispatcher(); + self.future = std.Io.concurrent(threadedIo(), runTask, .{self}) catch |err| { + self.err = mapAnyError(err); + self.dispatchCompletion(self.env); + return promise; + }; + self.controller_thread = try std.Thread.spawn(.{}, controllerThreadMain, .{self}); + }, + } + return promise; + } + + fn controllerThreadMain(self: *Self) void { + const io = threadedIo(); + const should_cancel = self.waitForTaskDoneOrAbort(); + if (self.future) |*future| { + if (should_cancel) { + future.cancel(io); + self.cancel_dispatched = true; + } else { + future.await(io); + } + } + self.queueCompletion() catch {}; + } + + fn runSingle(self: *Self) void { + const io = singleIo(); + var future = std.Io.async(io, runTask, .{self}); + future.await(io); + self.dispatchCompletion(self.env); + } + + fn runTask(self: *Self) void { + defer self.markTaskDone(); + + const task_runtime = effectiveRuntime(runtime); + const io = ioForRuntime(task_runtime); + var group: std.Io.Group = .init; + defer group.cancel(io); + + const context = Context{ + .allocator = self.allocator, + .io = io, + .group = &group, + .runtime = runtime, + .effective_runtime = switch (task_runtime) { + .single => .single, + .thread => .thread, + }, + .cancel_token = &self.cancel_token, + .emitter_ptr = if (Event == void) null else @ptrCast(self), + .emit_fn = if (Event == void) null else emitFromContext, + }; + + NapiError.clearLastError(); + self.execute(context) catch |err| { + self.err = mapAnyError(err); + return; + }; + group.await(io) catch |err| { + self.err = mapAnyError(err); + }; + } + + fn isAbortRequestedFromSignal(self: *Self) bool { + if (self.abort_registration) |registration| { + var signal_value: napi.napi_value = undefined; + const ref_status = napi.napi_get_reference_value(self.env, registration.signal_ref, &signal_value); + if (ref_status != napi.napi_ok or signal_value == null) return false; + return AbortSignal.from_raw(self.env, signal_value).isAborted() catch false; + } + return false; + } + + fn requestAbortFromSignal(ptr: ?*anyopaque) void { + const self: *Self = @ptrCast(@alignCast(ptr)); + self.requestAbort(); + } + + fn requestAbort(self: *Self) void { + const io = threadedIo(); + self.cancel_token.cancel(); + self.state_mutex.lockUncancelable(io); + defer self.state_mutex.unlock(io); + if (self.task_done) return; + self.cancel_requested = true; + self.state_cond.signal(io); + } + + fn markTaskDone(self: *Self) void { + const io = threadedIo(); + self.state_mutex.lockUncancelable(io); + defer self.state_mutex.unlock(io); + self.task_done = true; + self.state_cond.signal(io); + } + + fn waitForTaskDoneOrAbort(self: *Self) bool { + const io = threadedIo(); + self.state_mutex.lockUncancelable(io); + defer self.state_mutex.unlock(io); + + while (!self.task_done and !self.cancel_requested) { + self.state_cond.waitUncancelable(io, &self.state_mutex); + } + return self.cancel_requested and !self.task_done; + } + + fn execute(self: *Self, context: Context) !void { + if (run_info.params.len == 1) { + if (@typeInfo(run_info.return_type.?) == .error_union) { + if (Result == void) { + try run_fn(self.input); + } else { + self.result = try run_fn(self.input); + } + } else { + if (Result == void) { + _ = run_fn(self.input); + } else { + self.result = run_fn(self.input); + } + } + } else { + if (@typeInfo(run_info.return_type.?) == .error_union) { + if (Result == void) { + try run_fn(context, self.input); + } else { + self.result = try run_fn(context, self.input); + } + } else { + if (Result == void) { + _ = run_fn(context, self.input); + } else { + self.result = run_fn(context, self.input); + } + } + } + } + + fn emitFromContext(ptr: ?*anyopaque, event: Event) anyerror!void { + const self: *Self = @ptrCast(@alignCast(ptr)); + try self.cancel_token.check(); + + switch (effectiveRuntime(runtime)) { + .single => self.dispatchEvent(self.env, event), + .thread => { + const payload = try self.allocator.create(Event); + payload.* = event; + errdefer self.allocator.destroy(payload); + + const data = try self.allocator.create(DispatchData); + data.* = .{ .kind = .event, .payload = payload }; + errdefer self.allocator.destroy(data); + + const status = napi.napi_call_threadsafe_function(self.tsfn_raw, @ptrCast(data), napi.napi_tsfn_nonblocking); + if (status != napi.napi_ok) { + return NapiError.Error.fromStatus(NapiError.Status.New(status)); + } + }, + } + } + + fn dispatchEvent(self: *Self, env_raw: napi.napi_env, event: Event) void { + if (Event == void or self.listener_ref == null) return; + + var callback: napi.napi_value = undefined; + const get_ref_status = napi.napi_get_reference_value(env_raw, self.listener_ref.?, &callback); + if (get_ref_status != napi.napi_ok) return; + + const event_value = Napi.to_napi_value(env_raw, event, null) catch return; + const undefined_value = Undefined.New(Env.from_raw(env_raw)); + const argv = [1]napi.napi_value{event_value}; + var ignored: napi.napi_value = undefined; + _ = napi.napi_call_function(env_raw, undefined_value.raw, callback, argv.len, &argv, &ignored); + } + + fn queueCompletion(self: *Self) !void { + const data = try self.allocator.create(DispatchData); + data.* = .{ .kind = .completion }; + errdefer self.allocator.destroy(data); + + const status = napi.napi_call_threadsafe_function(self.tsfn_raw, @ptrCast(data), napi.napi_tsfn_nonblocking); + if (status != napi.napi_ok) { + return NapiError.Error.fromStatus(NapiError.Status.New(status)); + } + } + + fn dispatchCompletion(self: *Self, env_raw: napi.napi_env) void { + if (self.controller_thread) |thread| { + thread.join(); + self.controller_thread = null; + } + + if (self.cancel_dispatched or self.cancel_requested) { + self.promise.RejectAbortError() catch {}; + } else if (self.err) |err| { + self.promise.Reject(err) catch {}; + } else if (Result == void) { + self.promise.Resolve({}) catch {}; + } else { + self.promise.Resolve(self.result) catch {}; + } + + self.destroy(env_raw); + } + + fn initThreadDispatcher(self: *Self) !void { + const resource_name = String.New(Env.from_raw(self.env), "ZigAsyncTask"); + var dispatcher_fn: napi.napi_value = undefined; + const create_fn_status = napi.napi_create_function( + self.env, + "zigAsyncTask", + "zigAsyncTask".len, + dispatcherNoop, + null, + &dispatcher_fn, + ); + if (create_fn_status != napi.napi_ok) { + return NapiError.Error.fromStatus(NapiError.Status.New(create_fn_status)); + } + + var tsfn_raw: napi.napi_threadsafe_function = null; + const create_status = napi.napi_create_threadsafe_function( + self.env, + dispatcher_fn, + null, + resource_name.raw, + 0, + 1, + null, + dispatcherFinalize, + @ptrCast(self), + dispatcherCallJs, + &tsfn_raw, + ); + if (create_status != napi.napi_ok) { + return NapiError.Error.fromStatus(NapiError.Status.New(create_status)); + } + self.tsfn_raw = tsfn_raw; + } + + fn dispatcherNoop(inner_env: napi.napi_env, _: napi.napi_callback_info) callconv(.c) napi.napi_value { + return Undefined.New(Env.from_raw(inner_env)).raw; + } + + fn dispatcherFinalize(_: napi.napi_env, _: ?*anyopaque, _: ?*anyopaque) callconv(.c) void {} + + fn dispatcherCallJs(inner_env: napi.napi_env, _: napi.napi_value, context: ?*anyopaque, raw_data: ?*anyopaque) callconv(.c) void { + const self: *Self = @ptrCast(@alignCast(context)); + const data: *DispatchData = @ptrCast(@alignCast(raw_data)); + const allocator = self.allocator; + defer allocator.destroy(data); + + switch (data.kind) { + .event => { + if (Event != void and data.payload != null) { + const payload = data.payload.?; + defer allocator.destroy(payload); + self.dispatchEvent(inner_env, payload.*); + } + }, + .completion => self.dispatchCompletion(inner_env), + } + } + + fn destroy(self: *Self, env_raw: napi.napi_env) void { + if (self.closed) return; + self.closed = true; + + releaseCallbackRef(env_raw, &self.listener_ref); + if (self.abort_registration) |registration| { + registration.release(); + self.abort_registration = null; + } + if (self.tsfn_raw != null) { + _ = napi.napi_release_threadsafe_function(self.tsfn_raw, napi.napi_tsfn_release); + self.tsfn_raw = null; + } + self.allocator.destroy(self); + } + }; +} + +test "Async descriptor exposes runtime metadata" { + const Task = Async(u32, .thread); + try std.testing.expect(Task.is_napi_async_descriptor); + try std.testing.expect(Task.async_result_type == u32); + try std.testing.expect(Task.async_event_type == void); + try std.testing.expect(Task.async_runtime_model == .thread); +} + +test "AsyncWithEvents descriptor marks callback support" { + const Event = struct { current: u32 }; + const Task = AsyncWithEvents(u32, Event, .single); + try std.testing.expect(Task.async_has_events); + try std.testing.expect(Task.async_event_type == Event); +} diff --git a/src/napi/util/helper.zig b/src/napi/util/helper.zig index cd0f5c4..9e7427e 100644 --- a/src/napi/util/helper.zig +++ b/src/napi/util/helper.zig @@ -84,6 +84,22 @@ pub fn isThreadSafeFunction(comptime T: type) bool { return false; } +pub fn isAsyncDescriptor(comptime T: type) bool { + switch (@typeInfo(T)) { + .@"struct", .@"enum", .@"union", .@"opaque" => {}, + else => return false, + } + return @hasDecl(T, "is_napi_async_descriptor"); +} + +pub fn isAbortSignal(comptime T: type) bool { + switch (@typeInfo(T)) { + .@"struct", .@"enum", .@"union", .@"opaque" => {}, + else => return false, + } + return @hasDecl(T, "is_napi_abort_signal"); +} + pub fn isTypedArray(comptime T: type) bool { return @hasDecl(T, "is_napi_typedarray"); } diff --git a/src/napi/util/napi.zig b/src/napi/util/napi.zig index a5e6379..aefcb70 100644 --- a/src/napi/util/napi.zig +++ b/src/napi/util/napi.zig @@ -10,6 +10,7 @@ const class = @import("../wrapper/class.zig"); const Buffer = @import("../wrapper/buffer.zig").Buffer; const ArrayBuffer = @import("../wrapper/arraybuffer.zig").ArrayBuffer; const DataView = @import("../wrapper/dataview.zig").DataView; +const AbortSignal = @import("../abort_signal.zig").AbortSignal; fn napiTypeOf(env: napi.napi_env, raw: napi.napi_value) napi.napi_valuetype { var value_type: napi.napi_valuetype = undefined; @@ -161,6 +162,7 @@ fn valueMatchesType(env: napi.napi_env, raw: napi.napi_value, comptime T: type) break :blk valueMatchesType(env, raw, infos.optional.child); }, .@"struct" => blk: { + if (comptime helper.isAbortSignal(T)) break :blk napiTypeOf(env, raw) == napi.napi_object; if (comptime helper.isNapiFunction(T)) break :blk napiTypeOf(env, raw) == napi.napi_function; if (comptime helper.isTypedArray(T)) break :blk isTypedArrayValue(env, raw); if (comptime helper.isDataView(T)) break :blk isDataViewValue(env, raw); @@ -242,6 +244,9 @@ pub const Napi = struct { return NapiValue.Array.from_napi_value(env, raw, T); }, .@"struct" => { + if (comptime helper.isAbortSignal(T)) { + return AbortSignal.from_napi_value(env, raw); + } if (comptime helper.isNapiFunction(T)) { const fn_infos = @typeInfo(T); comptime var args_type = void; @@ -389,6 +394,9 @@ pub const Napi = struct { if (comptime helper.isNapiFunction(value_type)) { return value.raw; } + if (comptime helper.isAsyncDescriptor(value_type)) { + @compileError("Async descriptors can only be returned from exported functions"); + } if (comptime helper.isTypedArray(value_type)) { return value.raw; } diff --git a/src/napi/value/function.zig b/src/napi/value/function.zig index 20b77cf..0e88706 100644 --- a/src/napi/value/function.zig +++ b/src/napi/value/function.zig @@ -7,6 +7,8 @@ const NapiError = @import("../wrapper/error.zig"); const Undefined = @import("./undefined.zig").Undefined; const GlobalAllocator = @import("../util/allocator.zig"); const Reference = @import("../wrapper/reference.zig").Reference; +const helper = @import("../util/helper.zig"); +const AbortSignal = @import("../abort_signal.zig").AbortSignal; pub fn Function(comptime Args: type, comptime Return: type) type { const ArgsInfos = @typeInfo(Args); @@ -37,7 +39,18 @@ pub fn Function(comptime Args: type, comptime Return: type) type { const FnImpl = struct { fn inner_fn(inner_env: napi.napi_env, info: napi.napi_callback_info) callconv(.c) napi.napi_value { const undefined_value = Undefined.New(Env.from_raw(inner_env)); - var init_argc: usize = params.len; + const return_info = infos.@"fn".return_type.?; + const return_payload = switch (@typeInfo(return_info)) { + .error_union => |eu| eu.payload, + else => return_info, + }; + const async_returns_descriptor = comptime helper.isAsyncDescriptor(return_payload); + const has_async_events = comptime async_returns_descriptor and return_payload.async_has_events; + const has_env = comptime params.len > 0 and params[0].type.? == Env; + const env_index = if (has_env) 1 else 0; + const expected_argc = params.len - env_index + if (has_async_events) 1 else 0; + + var init_argc: usize = expected_argc; const allocator = GlobalAllocator.globalAllocator(); const args_raw = allocator.alloc(napi.napi_value, init_argc) catch @panic("OOM"); @@ -48,19 +61,20 @@ pub fn Function(comptime Args: type, comptime Return: type) type { return NapiError.checkNapiStatus(inner_env, NapiError.Status.New(cb_status)); } - const has_env = comptime params.len > 0 and params[0].type.? == Env; - const env_index = if (has_env) 1 else 0; - var napi_params: std.meta.ArgsTuple(value_type) = undefined; if (comptime has_env) { napi_params[0] = Env.from_raw(inner_env); } + var abort_signal: ?AbortSignal = null; inline for (params[env_index..], env_index..) |param_index, i| { if (comptime @typeInfo(param_index.type.?) == .@"union") { NapiError.clearLastError(); } napi_params[i] = Napi.from_napi_value(inner_env, args_raw[i - env_index], param_index.type.?); + if (comptime helper.isAbortSignal(param_index.type.?)) { + abort_signal = napi_params[i]; + } if (comptime @typeInfo(param_index.type.?) == .@"union") { if (NapiError.last_error) |last_err| { last_err.throwInto(Env.from_raw(inner_env)); @@ -69,7 +83,10 @@ pub fn Function(comptime Args: type, comptime Return: type) type { } } - const return_info = infos.@"fn".return_type.?; + const event_listener = if (has_async_events and init_argc > params.len - env_index) + args_raw[init_argc - 1] + else + null; if (@typeInfo(return_info) == .error_union) { const ret = @call(.auto, value, napi_params) catch { @@ -78,6 +95,16 @@ pub fn Function(comptime Args: type, comptime Return: type) type { } return undefined_value.raw; }; + if (comptime async_returns_descriptor) { + var task = ret; + const promise = task.scheduleWithListenerAndSignal(Env.from_raw(inner_env), event_listener, abort_signal) catch { + if (NapiError.last_error) |last_err| { + last_err.throwInto(Env.from_raw(inner_env)); + } + return undefined_value.raw; + }; + return promise.raw; + } const n_value = Napi.to_napi_value(inner_env, ret, null) catch { if (NapiError.last_error) |last_err| { last_err.throwInto(Env.from_raw(inner_env)); @@ -87,6 +114,16 @@ pub fn Function(comptime Args: type, comptime Return: type) type { return n_value; } else { const ret = @call(.auto, value, napi_params); + if (comptime async_returns_descriptor) { + var task = ret; + const promise = task.scheduleWithListenerAndSignal(Env.from_raw(inner_env), event_listener, abort_signal) catch { + if (NapiError.last_error) |last_err| { + last_err.throwInto(Env.from_raw(inner_env)); + } + return undefined_value.raw; + }; + return promise.raw; + } const n_value = Napi.to_napi_value(inner_env, ret, null) catch { if (NapiError.last_error) |last_err| { last_err.throwInto(Env.from_raw(inner_env)); diff --git a/src/napi/value/promise.zig b/src/napi/value/promise.zig index 4c50b99..33ba277 100644 --- a/src/napi/value/promise.zig +++ b/src/napi/value/promise.zig @@ -4,6 +4,7 @@ const Env = @import("../env.zig").Env; const Napi = @import("../util/napi.zig").Napi; const NapiValue = @import("../value.zig").NapiValue; const NapiError = @import("../wrapper/error.zig"); +const AbortSignal = @import("../abort_signal.zig"); pub const PromiseStatus = enum { Pending, @@ -62,4 +63,13 @@ pub const Promise = struct { } self.status = .Rejected; } + + pub fn RejectAbortError(self: *Self) !void { + const napi_value = AbortSignal.abortErrorValue(Env.from_raw(self.env)); + const s = napi.napi_reject_deferred(self.env, self.deferred, napi_value); + if (s != napi.napi_ok) { + return NapiError.Error.fromStatus(NapiError.Status.New(s)); + } + self.status = .Rejected; + } }; diff --git a/src/napi/wrapper/thread_safe_function.zig b/src/napi/wrapper/thread_safe_function.zig index e658ff4..68cd47b 100644 --- a/src/napi/wrapper/thread_safe_function.zig +++ b/src/napi/wrapper/thread_safe_function.zig @@ -13,6 +13,7 @@ pub const ThreadSafeFunctionMode = enum { Blocking, const Self = @This(); + pub fn to_raw(self: Self) napi.napi_threadsafe_function_call_mode { return switch (self) { .NonBlocking => napi.napi_tsfn_nonblocking, @@ -21,6 +22,20 @@ pub const ThreadSafeFunctionMode = enum { } }; +pub const ThreadSafeFunctionReleaseMode = enum { + Release, + Abort, + + const Self = @This(); + + pub fn to_raw(self: Self) napi.napi_threadsafe_function_release_mode { + return switch (self) { + .Release => napi.napi_tsfn_release, + .Abort => napi.napi_tsfn_abort, + }; + } +}; + pub const ThreadSafeFunctionCallVariant = enum { Direct, WithCallback, @@ -41,6 +56,8 @@ pub fn ThreadSafeFunction(comptime Args: type, comptime Return: type, comptime T allocator: std.mem.Allocator, args: Args, return_type: Return, + closed: bool, + aborted: bool, comptime thread_safe_function_call_variant: bool = ThreadSafeFunctionCalleeHandled, comptime max_queue_size: usize = MaxQueueSize, @@ -50,6 +67,7 @@ pub fn ThreadSafeFunction(comptime Args: type, comptime Return: type, comptime T const ThreadSafe = struct { fn finalize(_: napi.napi_env, data: ?*anyopaque, _: ?*anyopaque) callconv(.c) void { const self: *Self = @ptrCast(@alignCast(data)); + self.closed = true; self.deinit(); } @@ -70,10 +88,8 @@ pub fn ThreadSafeFunction(comptime Args: type, comptime Return: type, comptime T if (self.thread_safe_function_call_variant) { if (args.err) |param| { argv[0] = param.to_napi_error(Env.from_raw(inner_env)); - // if err, return immediately var ret: napi.napi_value = undefined; _ = napi.napi_call_function(inner_env, undefined_value.raw, js_callback, args_len + call_variant, argv.ptr, &ret); - // Free the call data allocator.destroy(param); allocator.destroy(args); return; @@ -90,14 +106,11 @@ pub fn ThreadSafeFunction(comptime Args: type, comptime Return: type, comptime T } else { argv[call_variant] = Napi.to_napi_value(inner_env, actual_args.*, null) catch null; } - // Free the args data allocator.destroy(actual_args); } var ret: napi.napi_value = undefined; _ = napi.napi_call_function(inner_env, undefined_value.raw, js_callback, args_len + call_variant, argv.ptr, &ret); - - // Free the call data allocator.destroy(args); } }; @@ -105,11 +118,36 @@ pub fn ThreadSafeFunction(comptime Args: type, comptime Return: type, comptime T const allocator = GlobalAllocator.globalAllocator(); var self = allocator.create(Self) catch @panic("OOM"); - self.* = Self{ .env = env, .raw = raw, .allocator = allocator, .args = undefined, .return_type = undefined, .tsfn_raw = undefined }; + self.* = Self{ + .env = env, + .raw = raw, + .allocator = allocator, + .args = undefined, + .return_type = undefined, + .tsfn_raw = null, + .closed = false, + .aborted = false, + }; - var tsfn_raw: napi.napi_threadsafe_function = undefined; + var tsfn_raw: napi.napi_threadsafe_function = null; const resource = String.New(Env.from_raw(env), "ThreadSafeFunction"); - _ = napi.napi_create_threadsafe_function(env, raw, null, resource.raw, 0, 1, @ptrCast(self), null, @ptrCast(self), ThreadSafe.cb, &tsfn_raw); + const create_status = napi.napi_create_threadsafe_function( + env, + raw, + null, + resource.raw, + self.max_queue_size, + 1, + @ptrCast(self), + ThreadSafe.finalize, + @ptrCast(self), + ThreadSafe.cb, + &tsfn_raw, + ); + if (create_status != napi.napi_ok) { + allocator.destroy(self); + @panic("Failed to create ThreadSafeFunction"); + } self.tsfn_raw = tsfn_raw; @@ -120,24 +158,81 @@ pub fn ThreadSafeFunction(comptime Args: type, comptime Return: type, comptime T self.allocator.destroy(self); } - pub fn Ok(self: *const Self, args: Args, mode: ThreadSafeFunctionMode) void { + fn freeCallData(self: *const Self, data: *CallData(Args)) void { + if (data.args) |actual_args| { + self.allocator.destroy(actual_args); + } + if (data.err) |actual_err| { + self.allocator.destroy(actual_err); + } + self.allocator.destroy(data); + } + + fn callThreadSafeFunction(self: *const Self, data: *CallData(Args), mode: ThreadSafeFunctionMode) !void { + const status = napi.napi_call_threadsafe_function(self.tsfn_raw, @ptrCast(data), mode.to_raw()); + if (status != napi.napi_ok) { + self.freeCallData(data); + return NapiError.Error.fromStatus(NapiError.Status.New(status)); + } + } + + pub fn acquire(self: *const Self) !void { + const status = napi.napi_acquire_threadsafe_function(self.tsfn_raw); + if (status != napi.napi_ok) { + return NapiError.Error.fromStatus(NapiError.Status.New(status)); + } + } + + pub fn release(self: *const Self, mode: ThreadSafeFunctionReleaseMode) !void { + const status = napi.napi_release_threadsafe_function(self.tsfn_raw, mode.to_raw()); + if (status != napi.napi_ok) { + return NapiError.Error.fromStatus(NapiError.Status.New(status)); + } + } + + pub fn abort(self: *Self) !void { + if (self.aborted) return; + try self.release(.Abort); + self.aborted = true; + } + + pub fn ref(self: *const Self) !void { + const status = napi.napi_ref_threadsafe_function(self.env, self.tsfn_raw); + if (status != napi.napi_ok) { + return NapiError.Error.fromStatus(NapiError.Status.New(status)); + } + } + + pub fn unref(self: *const Self) !void { + const status = napi.napi_unref_threadsafe_function(self.env, self.tsfn_raw); + if (status != napi.napi_ok) { + return NapiError.Error.fromStatus(NapiError.Status.New(status)); + } + } + + pub fn Ok(self: *const Self, args: Args, mode: ThreadSafeFunctionMode) !void { const args_data = self.allocator.create(Args) catch @panic("OOM"); args_data.* = args; const data = self.allocator.create(CallData(Args)) catch @panic("OOM"); data.* = CallData(Args){ .args = args_data, .err = null }; - _ = napi.napi_call_threadsafe_function(self.tsfn_raw, @ptrCast(data), mode.to_raw()); + try self.callThreadSafeFunction(data, mode); } - pub fn Err(self: *const Self, err: NapiError.Error, mode: ThreadSafeFunctionMode) void { - const e = self.allocator.create(NapiError.Error) catch @panic("OOM"); - e.* = err; + pub fn Err(self: *const Self, err: NapiError.Error, mode: ThreadSafeFunctionMode) !void { + const actual_err = self.allocator.create(NapiError.Error) catch @panic("OOM"); + actual_err.* = err; const data = self.allocator.create(CallData(Args)) catch @panic("OOM"); - data.* = CallData(Args){ .args = null, .err = e }; + data.* = CallData(Args){ .args = null, .err = actual_err }; - _ = napi.napi_call_threadsafe_function(self.tsfn_raw, @ptrCast(data), mode.to_raw()); + try self.callThreadSafeFunction(data, mode); } }; } + +test "ThreadSafeFunction release modes map to napi values" { + try std.testing.expect(ThreadSafeFunctionReleaseMode.Release.to_raw() == napi.napi_tsfn_release); + try std.testing.expect(ThreadSafeFunctionReleaseMode.Abort.to_raw() == napi.napi_tsfn_abort); +} diff --git a/src/napi/wrapper/worker.zig b/src/napi/wrapper/worker.zig index e541e10..bcf7a73 100644 --- a/src/napi/wrapper/worker.zig +++ b/src/napi/wrapper/worker.zig @@ -2,10 +2,7 @@ const std = @import("std"); const napi = @import("napi-sys").napi_sys; const napi_env = @import("../env.zig"); const String = @import("../value/string.zig").String; -const napi_status = @import("./status.zig"); -const Value = @import("../value.zig").Value; const Promise = @import("../value/promise.zig").Promise; -const Napi = @import("../util/napi.zig").Napi; const NapiError = @import("./error.zig"); const GlobalAllocator = @import("../util/allocator.zig"); @@ -17,42 +14,41 @@ const WorkerStatus = enum { }; pub fn WorkerContext(comptime T: type) type { - const hasData = comptime @hasField(T, "data"); - const hasExecute = comptime @hasField(T, "Execute"); - const hasOnComplete = comptime @hasField(T, "OnComplete"); + const has_data = comptime @hasField(T, "data"); + const has_execute = comptime @hasField(T, "Execute"); + const has_on_complete = comptime @hasField(T, "OnComplete"); - if (!hasData) { + if (!has_data) { @compileError("Worker must init with data field"); } - - if (!hasExecute) { + if (!has_execute) { @compileError("Worker must init with Execute field"); } - - const type_info = @typeInfo(T); - if (type_info != .@"struct") { - @compileError("T must be a struct type"); + if (@typeInfo(T) != .@"struct") { + @compileError("Worker init data must be a struct"); } const DataType = @TypeOf(@as(T, undefined).data); const ExecuteFn = @TypeOf(@as(T, undefined).Execute); - const ExecuteFnInfos = @typeInfo(ExecuteFn); - const OnComplete = if (hasOnComplete) @TypeOf(@as(T, undefined).OnComplete) else void; - const OnCompleteInfos = @typeInfo(OnComplete); - - const ExecuteResult = ExecuteFnInfos.@"fn".return_type.?; + const ExecuteInfo = @typeInfo(ExecuteFn); + if (ExecuteInfo != .@"fn") { + @compileError("Execute must be a function"); + } - const ExecuteResultType = switch (@typeInfo(ExecuteResult)) { + const ExecuteReturn = ExecuteInfo.@"fn".return_type.?; + const ExecutePayload = switch (@typeInfo(ExecuteReturn)) { .error_union => |eu| eu.payload, - else => ExecuteResult, + else => ExecuteReturn, }; - if (ExecuteFnInfos != .@"fn") { - @compileError("Execute must be a function"); - } + comptime validateExecuteSignature(DataType, ExecuteFn); - if (hasOnComplete and OnCompleteInfos != .@"fn") { - @compileError("OnComplete must be a function"); + if (has_on_complete) { + const OnComplete = @TypeOf(@as(T, undefined).OnComplete); + if (@typeInfo(OnComplete) != .@"fn") { + @compileError("OnComplete must be a function"); + } + comptime validateOnCompleteSignature(DataType, OnComplete); } return struct { @@ -60,154 +56,53 @@ pub fn WorkerContext(comptime T: type) type { env: napi.napi_env, raw: napi.napi_async_work, allocator: std.mem.Allocator, - result: ExecuteResultType, - err: ?NapiError.Error, - status: WorkerStatus, - promise: ?*Promise, + result: ExecutePayload = if (ExecutePayload == void) {} else undefined, + err: ?NapiError.Error = null, + status: WorkerStatus = .Pending, + promise: ?*Promise = null, const Self = @This(); pub fn New(env: napi_env.Env, init_data: anytype) *Self { - const Execute = struct { - fn inner_execute(inner_env: napi.napi_env, data: ?*anyopaque) callconv(.c) void { - const inner_self: *Self = @ptrCast(@alignCast(data)); - const params = ExecuteFnInfos.@"fn".params; - const return_type = @typeInfo(ExecuteFnInfos.@"fn".return_type.?); - - switch (params.len) { - 1 => { - if (params[0].type.? != DataType) { - @compileError("Execute's first parameter must be " ++ @typeName(DataType)); - } else { - if (return_type == .error_union) { - inner_self.result = inner_self.data.Execute(inner_self.data.data) catch { - inner_self.status = .Rejected; - // transfer the last error from the worker thread to the main thread - inner_self.err = NapiError.last_error; - return; - }; - } else { - inner_self.result = inner_self.data.Execute(inner_self.data.data); - } - inner_self.status = .Resolved; - } - }, - 2 => { - if (params[0].type.? != napi_env.Env) { - @compileError("Execute's first parameter must be napi.Env"); - } else if (params[1].type.? != DataType) { - @compileError("Execute's second parameter must be " ++ @typeName(DataType)); - } else { - if (return_type == .error_union) { - inner_self.result = inner_self.data.Execute(napi_env.Env.from_raw(inner_env), inner_self.data.data) catch { - inner_self.status = .Rejected; - inner_self.err = NapiError.last_error; - return; - }; - } else { - inner_self.result = inner_self.data.Execute(napi_env.Env.from_raw(inner_env), inner_self.data.data); - } - inner_self.status = .Resolved; - } - }, - else => { - @compileError("Execute must have 1 or 2 parameters, but got " ++ std.fmt.comptimePrint("{d}", .{params.len})); - }, - } - } - }; - - const Complete = struct { - fn inner_complete(inner_env: napi.napi_env, _: napi.napi_status, data: ?*anyopaque) callconv(.c) void { - const inner_self: *Self = @ptrCast(@alignCast(data)); - - switch (inner_self.status) { - .Rejected => { - if (inner_self.promise) |promise| { - if (inner_self.err) |err| { - promise.Reject(err) catch { - if (NapiError.last_error) |last_err| { - last_err.throwInto(napi_env.Env.from_raw(inner_env)); - } - return; - }; - } - } - }, - .Resolved => { - if (inner_self.promise) |promise| { - const napi_data = Napi.to_napi_value(inner_env, inner_self.result, null) catch { - if (NapiError.last_error) |last_err| { - last_err.throwInto(napi_env.Env.from_raw(inner_env)); - } - return; - }; - promise.Resolve(napi_data) catch { - if (NapiError.last_error) |last_err| { - last_err.throwInto(napi_env.Env.from_raw(inner_env)); - } - return; - }; - } - }, - else => {}, - } - const hasComplete = comptime @hasField(T, "OnComplete"); - if (hasComplete) { - const params = OnCompleteInfos.@"fn".params; - switch (params.len) { - 1 => { - if (params[0].type.? != DataType) { - @compileError("OnComplete's first parameter must be " ++ @typeName(DataType)); - } else { - inner_self.data.OnComplete(inner_self.data.data); - } - }, - 2 => { - if (params[0].type.? != napi_env.Env) { - @compileError("OnComplete's first parameter must be napi.Env"); - } else if (params[1].type.? != DataType) { - @compileError("OnComplete's second parameter must be napi.Status"); - } else { - inner_self.data.OnComplete(napi_env.Env.from_raw(inner_env), inner_self.data.data); - } - }, - else => { - @compileError("OnComplete must have 1 or 2 parameters, but got " ++ std.fmt.comptimePrint("{d}", .{params.len})); - }, - } - } - inner_self.deinit(); - } - }; - const allocator = GlobalAllocator.globalAllocator(); - var self = allocator.create(Self) catch @panic("OOM"); + const self = allocator.create(Self) catch @panic("OOM"); - self.* = Self{ + self.* = .{ .data = init_data, .env = env.raw, - .raw = undefined, + .raw = null, .allocator = allocator, - .result = undefined, - .status = .Pending, - .promise = null, - .err = null, }; const async_resource_name = String.New(env, "AsyncWorkerCallback"); - var result: napi.napi_async_work = undefined; - _ = napi.napi_create_async_work(env.raw, null, async_resource_name.raw, Execute.inner_execute, Complete.inner_complete, @ptrCast(self), &result); + var result: napi.napi_async_work = null; + const status = napi.napi_create_async_work( + env.raw, + null, + async_resource_name.raw, + execute, + complete, + @ptrCast(self), + &result, + ); + if (status != napi.napi_ok) { + allocator.destroy(self); + @panic("Failed to create async worker"); + } self.raw = result; - return self; } pub fn deinit(self: *Self) void { + if (self.raw != null) { + _ = napi.napi_delete_async_work(self.env, self.raw); + self.raw = null; + } if (self.promise) |promise| { self.allocator.destroy(promise); + self.promise = null; } self.allocator.destroy(self); } @@ -227,9 +122,144 @@ pub fn WorkerContext(comptime T: type) type { pub fn Cancel(self: *Self) void { _ = napi.napi_cancel_async_work(self.env, self.raw); } + + fn execute(inner_env: napi.napi_env, data: ?*anyopaque) callconv(.c) void { + const self: *Self = @ptrCast(@alignCast(data)); + NapiError.clearLastError(); + self.run(inner_env) catch { + self.status = .Rejected; + self.err = NapiError.last_error orelse NapiError.Error.withStatus("GenericFailure"); + return; + }; + self.status = .Resolved; + } + + fn complete(inner_env: napi.napi_env, status: napi.napi_status, data: ?*anyopaque) callconv(.c) void { + const self: *Self = @ptrCast(@alignCast(data)); + defer self.deinit(); + + if (status == napi.napi_cancelled) { + self.status = .Cancelled; + } + + switch (self.status) { + .Rejected => { + if (self.promise) |promise| { + if (self.err) |err| { + promise.Reject(err) catch { + if (NapiError.last_error) |last_err| { + last_err.throwInto(napi_env.Env.from_raw(inner_env)); + } + }; + } + } else if (self.err) |err| { + err.throwInto(napi_env.Env.from_raw(inner_env)); + } + }, + .Resolved => { + if (self.promise) |promise| { + if (ExecutePayload == void) { + promise.Resolve({}) catch {}; + } else { + promise.Resolve(self.result) catch { + if (NapiError.last_error) |last_err| { + last_err.throwInto(napi_env.Env.from_raw(inner_env)); + } + }; + } + } + }, + else => {}, + } + + if (has_on_complete) { + callOnComplete(self.data, napi_env.Env.from_raw(inner_env)); + } + } + + fn run(self: *Self, inner_env: napi.napi_env) !void { + const execute_fn = self.data.Execute; + if (@typeInfo(ExecuteReturn) == .error_union) { + if (ExecutePayload == void) { + if (ExecuteInfo.@"fn".params.len == 1) { + try execute_fn(self.data.data); + } else { + try execute_fn(napi_env.Env.from_raw(inner_env), self.data.data); + } + } else { + self.result = if (ExecuteInfo.@"fn".params.len == 1) + try execute_fn(self.data.data) + else + try execute_fn(napi_env.Env.from_raw(inner_env), self.data.data); + } + } else { + if (ExecutePayload == void) { + if (ExecuteInfo.@"fn".params.len == 1) { + _ = execute_fn(self.data.data); + } else { + _ = execute_fn(napi_env.Env.from_raw(inner_env), self.data.data); + } + } else { + self.result = if (ExecuteInfo.@"fn".params.len == 1) + execute_fn(self.data.data) + else + execute_fn(napi_env.Env.from_raw(inner_env), self.data.data); + } + } + } }; } +fn validateExecuteSignature(comptime DataType: type, comptime ExecuteFn: type) void { + const info = @typeInfo(ExecuteFn).@"fn"; + if (info.params.len != 1 and info.params.len != 2) { + @compileError("Worker Execute must accept (data) or (napi.Env, data)"); + } + + if (info.params.len == 1) { + if (info.params[0].type.? != DataType) { + @compileError("Worker Execute data type mismatch"); + } + } else { + if (info.params[0].type.? != napi_env.Env) { + @compileError("Worker Execute first parameter must be napi.Env"); + } + if (info.params[1].type.? != DataType) { + @compileError("Worker Execute data type mismatch"); + } + } +} + +fn validateOnCompleteSignature(comptime DataType: type, comptime OnCompleteFn: type) void { + const info = @typeInfo(OnCompleteFn).@"fn"; + if (info.params.len != 1 and info.params.len != 2) { + @compileError("Worker OnComplete must accept (data) or (napi.Env, data)"); + } + + if (info.params.len == 1) { + if (info.params[0].type.? != DataType) { + @compileError("Worker OnComplete data type mismatch"); + } + } else { + if (info.params[0].type.? != napi_env.Env) { + @compileError("Worker OnComplete first parameter must be napi.Env"); + } + if (info.params[1].type.? != DataType) { + @compileError("Worker OnComplete data type mismatch"); + } + } +} + +fn callOnComplete(data: anytype, env: napi_env.Env) void { + const OnCompleteFn = @TypeOf(data.OnComplete); + const info = @typeInfo(OnCompleteFn).@"fn"; + if (info.params.len == 1) { + data.OnComplete(data.data); + } else { + data.OnComplete(env, data.data); + } +} + pub fn Worker(env: napi_env.Env, data: anytype) *WorkerContext(@TypeOf(data)) { return WorkerContext(@TypeOf(data)).New(env, data); }