diff --git a/CMakeLists.txt b/CMakeLists.txt index 047ea43b7..77f2fa034 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,6 +15,7 @@ option(LOVR_ENABLE_HEADSET "Enable the headset module" ON) option(LOVR_ENABLE_MATH "Enable the math module" ON) option(LOVR_ENABLE_PHYSICS "Enable the physics module" ON) option(LOVR_ENABLE_SYSTEM "Enable the system module" ON) +option(LOVR_ENABLE_TASK "Enable the task module" ON) option(LOVR_ENABLE_THREAD "Enable the thread module" ON) option(LOVR_ENABLE_TIMER "Enable the timer module" ON) @@ -540,6 +541,15 @@ else() target_compile_definitions(lovr PRIVATE LOVR_DISABLE_SYSTEM) endif() +if(LOVR_ENABLE_TASK) + target_sources(lovr PRIVATE + src/modules/task/task.c + src/api/l_task.c + ) +else() + target_compile_definitions(lovr PRIVATE LOVR_DISABLE_TASK) +endif() + if(LOVR_ENABLE_THREAD) target_sources(lovr PRIVATE src/core/job.c diff --git a/etc/boot.lua b/etc/boot.lua index eac373b77..2dc717057 100644 --- a/etc/boot.lua +++ b/etc/boot.lua @@ -15,6 +15,7 @@ local conf = { math = true, physics = true, system = true, + task = true, thread = true, timer = true }, @@ -201,6 +202,11 @@ function lovr.run() dt = lovr.headset.update() if not lovr.headset.isActive() then lovr.simulate(dt) end end + if lovr.task then + for task in lovr.task.poll() do + lovr.taskready(task) + end + end if lovr.update then lovr.update(dt) end if lovr.audio then lovr.audio.update(dt) end if lovr.graphics then @@ -403,6 +409,10 @@ function lovr.threaderror(thread, err) error('Thread error\n\n' .. 
err, 0) end +function lovr.taskready(task) + assert(lovr.task.resume(task)) +end + function lovr.filechanged(path, action, oldpath) if not path:match('^%.') then lovr.event.restart() diff --git a/src/api/api.c b/src/api/api.c index 516bd2879..04538b768 100644 --- a/src/api/api.c +++ b/src/api/api.c @@ -30,6 +30,7 @@ LOVR_EXPORT int luaopen_lovr_headset(lua_State* L); LOVR_EXPORT int luaopen_lovr_math(lua_State* L); LOVR_EXPORT int luaopen_lovr_physics(lua_State* L); LOVR_EXPORT int luaopen_lovr_system(lua_State* L); +LOVR_EXPORT int luaopen_lovr_task(lua_State* L); LOVR_EXPORT int luaopen_lovr_thread(lua_State* L); LOVR_EXPORT int luaopen_lovr_timer(lua_State* L); @@ -55,27 +56,6 @@ static void luax_destructor(lua_State* L, void* userdata) { } #endif -static int luax_release(lua_State* L) { - Object* object = lua_touserdata(L, 1); - - if (!object) { - return 0; - } - - // Remove from userdata cache - lua_getfield(L, LUA_REGISTRYINDEX, "_lovrobjects"); - lua_pushlightuserdata(L, object->pointer); - lua_pushnil(L); - lua_rawset(L, -3); - lua_pop(L, 1); - - // Release - lovrRelease(object->pointer, lovrTypeInfo[object->type].destructor); - object->pointer = NULL; - - return 0; -} - void luax_preload(lua_State* L) { static const luaL_Reg lovrModules[] = { { "lovr", luaopen_lovr }, @@ -106,6 +86,9 @@ void luax_preload(lua_State* L) { #ifndef LOVR_DISABLE_SYSTEM { "lovr.system", luaopen_lovr_system }, #endif +#ifndef LOVR_DISABLE_TASK + { "lovr.task", luaopen_lovr_task }, +#endif #ifndef LOVR_DISABLE_THREAD { "lovr.thread", luaopen_lovr_thread }, #endif @@ -124,6 +107,27 @@ void luax_preload(lua_State* L) { lua_pop(L, 2); } +static int luax_release(lua_State* L) { + Object* object = lua_touserdata(L, 1); + + if (!object) { + return 0; + } + + // Remove from userdata cache + lua_getfield(L, LUA_REGISTRYINDEX, "_lovrobjects"); + lua_pushlightuserdata(L, object->pointer); + lua_pushnil(L); + lua_rawset(L, -3); + lua_pop(L, 1); + + // Release + lovrRelease(object->pointer, 
lovrTypeInfo[object->type].destructor); + object->pointer = NULL; + + return 0; +} + void _luax_registertype(lua_State* L, int type, const char* name, void (*destructor)(void*), const luaL_Reg* functions) { lovrTypeInfo[type] = (TypeInfo) { name, destructor }; @@ -144,7 +148,7 @@ void _luax_registertype(lua_State* L, int type, const char* name, void (*destruc lua_pushcfunction(L, luax_release); lua_setfield(L, -2, "__gc"); - // m.__close = gc + // m.__close = luax_release lua_pushcfunction(L, luax_release); lua_setfield(L, -2, "__close"); #endif @@ -153,11 +157,6 @@ void _luax_registertype(lua_State* L, int type, const char* name, void (*destruc lua_pushcfunction(L, luax_tostring); lua_setfield(L, -2, "__tostring"); - // Register methods - if (functions) { - luax_register(L, functions); - } - // :release method lua_pushcfunction(L, luax_release); lua_setfield(L, -2, "release"); @@ -166,6 +165,11 @@ void _luax_registertype(lua_State* L, int type, const char* name, void (*destruc lua_pushcfunction(L, luax_type); lua_setfield(L, -2, "type"); + // Register methods + if (functions) { + luax_register(L, functions); + } + // Pop metatable lua_pop(L, 1); } @@ -305,6 +309,33 @@ void luax_registerloader(lua_State* L, lua_CFunction loader, int index) { lua_pop(L, 1); } +static int luax_wrapasync(lua_State* L) { + if (luax_getthreaddata(L)) { + int n = lua_gettop(L); + lua_pushvalue(L, lua_upvalueindex(1)); + lua_insert(L, 1); + return luax_callthread(L, n); + } else { + lua_CFunction function = lua_tocfunction(L, lua_upvalueindex(1)); + return function(L); + } +} + +void _luax_registerasync(lua_State* L, const char* name, const char** methods) { + luaL_getmetatable(L, name); + + if (lua_istable(L, -1)) { + while (*methods) { + const char* method = *methods++; + lua_getfield(L, -1, method); + lua_pushcclosure(L, luax_wrapasync, 1); + lua_setfield(L, -2, method); + } + } + + lua_pop(L, 1); +} + int luax_resume(lua_State* T, int n) { #if LUA_VERSION_NUM >= 504 int results; @@ 
-433,6 +464,52 @@ void luax_pushstash(lua_State* L, const char* name) { } } +void* luax_getthreaddata(lua_State* L) { +#ifdef LOVR_USE_LUAU + return lua_getthreaddata(L); +#elif LUA_VERSION_NUM >= 503 + return *(void**) lua_getextraspace(L); +#else + lua_getfield(L, LUA_REGISTRYINDEX, "_lovrthreaddata"); + if (lua_isnil(L, -1)) return lua_pop(L, 1), NULL; + lua_pushthread(L); + lua_rawget(L, -2); + void* data = lua_touserdata(L, -1); + lua_pop(L, 2); + return data; +#endif +} + +void luax_setthreaddata(lua_State* L, void* data) { +#ifdef LOVR_USE_LUAU + lua_setthreaddata(L, data); +#elif LUA_VERSION_NUM >= 503 + *(void**) lua_getextraspace(L) = data; +#else + lua_getfield(L, LUA_REGISTRYINDEX, "_lovrthreaddata"); + if (lua_isnil(L, -1)) { + lua_pop(L, 1); + lua_newtable(L); + + lua_newtable(L); + lua_pushliteral(L, "k"); + lua_setfield(L, -2, "__mode"); + lua_setmetatable(L, -2); + + lua_pushvalue(L, -1); + lua_setfield(L, LUA_REGISTRYINDEX, "_lovrthreaddata"); + } + lua_pushthread(L); + if (data) { + lua_pushlightuserdata(L, data); + } else { + lua_pushnil(L); + } + lua_rawset(L, -3); + lua_pop(L, 1); +#endif +} + void luax_setmainthread(lua_State *L) { #if LUA_VERSION_NUM < 502 lua_pushthread(L); diff --git a/src/api/api.h b/src/api/api.h index abcfaf63f..ac3c9271e 100644 --- a/src/api/api.h +++ b/src/api/api.h @@ -104,6 +104,7 @@ typedef struct { #endif #define luax_registertype(L, T) _luax_registertype(L, T_ ## T, #T, lovr ## T ## Destroy, lovr ## T) +#define luax_registerasync(L, T) _luax_registerasync(L, #T, lovr ## T ## Async) #define luax_totype(L, i, T) (T*) _luax_totype(L, i, T_ ## T) #define luax_checktype(L, i, T) (T*) _luax_checktype(L, i, T_ ## T) #define luax_pushtype(L, T, o) _luax_pushtype(L, T_ ## T, o) @@ -123,6 +124,7 @@ int luax_typeerror(lua_State* L, int index, const char* expected); void _luax_pushtype(lua_State* L, int type, void* object); int _luax_checkenum(lua_State* L, int index, const StringEntry* map, const char* fallback, const char* 
label); void luax_registerloader(lua_State* L, int (*loader)(lua_State* L), int index); +void _luax_registerasync(lua_State* L, const char* type, const char** methods); int luax_resume(lua_State* T, int n); int luax_loadbufferx(lua_State* L, const char* buffer, size_t size, const char* name, const char* mode); void luax_vthrow(void* L, const char* format, va_list args); @@ -133,6 +135,8 @@ int luax_pushsuccess(lua_State* L, bool success); void luax_pushconf(lua_State* L); int luax_setconf(lua_State* L); void luax_pushstash(lua_State* L, const char* name); +void* luax_getthreaddata(lua_State* L); +void luax_setthreaddata(lua_State* L, void* data); void luax_setmainthread(lua_State* L); void luax_atexit(lua_State* L, void (*finalizer)(void)); void luax_close(lua_State* L); @@ -201,3 +205,14 @@ struct Shape* luax_newconvexshape(lua_State* L, int index); struct Shape* luax_newmeshshape(lua_State* L, int index); struct Shape* luax_newterrainshape(lua_State* L, int index); #endif + +#ifndef LOVR_DISABLE_TASK +#include "task/task.h" +Task* luax_gettask(lua_State* L); +int luax_yieldpoll(lua_State* L, fn_task* poll, fn_task* block, fn_continuation* continuation, void* context); +int luax_yieldjob(lua_State* L, fn_task* fn, fn_continuation* continuation, void* context); +#endif + +#ifndef LOVR_DISABLE_THREAD +int luax_callthread(lua_State* L, int n); +#endif diff --git a/src/api/l_data.c b/src/api/l_data.c index 0ae3dfa0e..80c5fb7d5 100644 --- a/src/api/l_data.c +++ b/src/api/l_data.c @@ -111,6 +111,20 @@ static int l_lovrDataNewBlobView(lua_State* L) { return 1; } +static int luax_pushimage(lua_State* L, void* context) { + luax_pushtype(L, Image, context); + lovrRelease(context, lovrImageDestroy); + return 1; +} + +static bool luax_loadimage(void** context) { + Blob* blob = *context; + Image* image = lovrImageCreateFromFile(blob); + lovrRelease(blob, lovrBlobDestroy); + *context = image; + return !!image; +} + static int l_lovrDataNewImage(lua_State* L) { Image* image = 
NULL; if (lua_type(L, 1) == LUA_TNUMBER) { @@ -139,9 +153,7 @@ static int l_lovrDataNewImage(lua_State* L) { memcpy(lovrImageGetLayerData(image, 0, 0), lovrImageGetLayerData(source, 0, 0), lovrImageGetLayerSize(image, 0)); } else { Blob* blob = luax_readblob(L, 1, "Texture"); - image = lovrImageCreateFromFile(blob); - lovrRelease(blob, lovrBlobDestroy); - luax_assert(L, image); + return luax_yieldjob(L, luax_loadimage, luax_pushimage, blob); } } diff --git a/src/api/l_graphics_buffer.c b/src/api/l_graphics_buffer.c index ff358a16b..f45a7c922 100644 --- a/src/api/l_graphics_buffer.c +++ b/src/api/l_graphics_buffer.c @@ -532,22 +532,46 @@ static int l_lovrBufferNewReadback(lua_State* L) { return 1; } +static bool luax_pollreadback(void** context) { + return lovrReadbackPoll(*context); +} + +static bool luax_waitreadback(void** context) { + return lovrReadbackWait(*context); +} + +static int luax_pushreadbackdata(lua_State* L, void* context) { + DataField* format = NULL; + uint32_t count = 0; + // Convert to Lua values before releasing the Readback, in case the Readback owns the data/format memory + void* data = lovrReadbackGetData(context, &format, &count); + int n = data && format ? luax_pushbufferdata(L, format, count, data) : 0; + lovrRelease(context, lovrReadbackDestroy); + luax_assert(L, data && format); + return n; +} + static int l_lovrBufferGetData(lua_State* L) { Buffer* buffer = luax_checktype(L, 1, Buffer); const DataField* format = lovrBufferGetInfo(buffer)->format; luax_check(L, format, "Buffer:getData requires the Buffer to have a format"); + + uint32_t offset, extent; if (format->length > 0) { uint32_t index = luax_optu32(L, 2, 1) - 1; luax_check(L, index < format->length, "Buffer:getData index exceeds the Buffer's length"); uint32_t count = luax_optu32(L, 3, format->length - index); - void* data = lovrBufferGetData(buffer, index * format->stride, count * format->stride); - luax_assert(L, data); - return luax_pushbufferdata(L, format, count, data); + offset = index * format->stride; + extent = count * format->stride; } else { - void* data = lovrBufferGetData(buffer, 0, format->stride); - 
luax_assert(L, data); - return luax_pushbufferdata(L, format, 0, data); + offset = 0; + extent = format->stride; } + + Readback* readback = lovrReadbackCreateBuffer(buffer, offset, extent); + luax_assert(L, readback); + + return luax_yieldpoll(L, luax_pollreadback, luax_waitreadback, luax_pushreadbackdata, readback); } static int l_lovrBufferSetData(lua_State* L) { diff --git a/src/api/l_graphics_mesh.c b/src/api/l_graphics_mesh.c index a7b58f254..d74757b98 100644 --- a/src/api/l_graphics_mesh.c +++ b/src/api/l_graphics_mesh.c @@ -60,12 +60,19 @@ static int l_lovrMeshSetIndexBuffer(lua_State* L) { static int l_lovrMeshGetVertices(lua_State* L) { Mesh* mesh = luax_checktype(L, 1, Mesh); + + if (lovrMeshGetStorage(mesh) == MESH_GPU) { + lua_pushnil(L); + return 1; + } + uint32_t index = luax_optu32(L, 2, 1) - 1; uint32_t count = luax_optu32(L, 3, ~0u); char* data = lovrMeshGetVertices(mesh, index, count); luax_assert(L, data); const DataField* format = lovrMeshGetVertexFormat(mesh); count = count == ~0u ? 
format->length - index : count; + lua_createtable(L, (int) count, 0); for (uint32_t i = 0; i < count; i++, data += format->stride) { lua_newtable(L); @@ -80,6 +87,7 @@ static int l_lovrMeshGetVertices(lua_State* L) { } lua_rawseti(L, -2, (int) i + 1); } + return 1; } @@ -112,6 +120,11 @@ static int l_lovrMeshSetVertices(lua_State* L) { static int l_lovrMeshGetIndices(lua_State* L) { Mesh* mesh = luax_checktype(L, 1, Mesh); + if (lovrMeshGetStorage(mesh) == MESH_GPU) { + lua_pushnil(L); + return 1; + } + uint32_t count; DataType type; void* data; diff --git a/src/api/l_graphics_texture.c b/src/api/l_graphics_texture.c index 9d615d0e8..be386e8d7 100644 --- a/src/api/l_graphics_texture.c +++ b/src/api/l_graphics_texture.c @@ -102,6 +102,23 @@ static int l_lovrTextureNewReadback(lua_State* L) { return 1; } +static bool luax_pollreadback(void** context) { + return lovrReadbackPoll(*context); +} + +static bool luax_waitreadback(void** context) { + return lovrReadbackWait(*context); +} + +static int luax_pushreadbackdata(lua_State* L, void* context) { + Image* image = lovrReadbackGetImage(context); + // Push the Image before releasing the Readback, in case the Readback holds the only reference to it + if (image) luax_pushtype(L, Image, image); + lovrRelease(context, lovrReadbackDestroy); + luax_assert(L, image); + return 1; +} + static int l_lovrTextureGetPixels(lua_State* L) { Texture* texture = luax_checktype(L, 1, Texture); uint32_t offset[4], extent[3]; @@ -112,11 +129,11 @@ static int l_lovrTextureGetPixels(lua_State* L) { extent[0] = luax_optu32(L, 6, ~0u); extent[1] = luax_optu32(L, 7, ~0u); extent[2] = 1; - Image* image = lovrTextureGetPixels(texture, offset, extent); - luax_assert(L, image); - luax_pushtype(L, Image, image); - lovrRelease(image, lovrImageDestroy); - return 1; + + Readback* readback = lovrReadbackCreateTexture(texture, offset, extent); + luax_assert(L, readback); + + return luax_yieldpoll(L, luax_pollreadback, luax_waitreadback, luax_pushreadbackdata, readback); } static int l_lovrTextureSetPixels(lua_State* L) { diff --git a/src/api/l_physics.c b/src/api/l_physics.c 
index 2df3fd050..3f37b9f5c 100644 --- a/src/api/l_physics.c +++ b/src/api/l_physics.c @@ -343,6 +343,8 @@ extern const luaL_Reg lovrDistanceJoint[]; extern const luaL_Reg lovrHingeJoint[]; extern const luaL_Reg lovrSliderJoint[]; +extern const char* lovrWorldAsync[]; + static void luax_unref(void* object, uintptr_t userdata) { if (!userdata) return; lua_State* L = (lua_State*) userdata; diff --git a/src/api/l_task.c b/src/api/l_task.c new file mode 100644 index 000000000..6f3898761 --- /dev/null +++ b/src/api/l_task.c @@ -0,0 +1,368 @@ +#include "api.h" +#include "task/task.h" +#include "core/job.h" +#include "util.h" +#include +#include + +static void luax_pintask(lua_State* L, Task* task) { + lua_getfield(L, LUA_REGISTRYINDEX, "_lovrtasks"); + lua_pushlightuserdata(L, task); + lua_pushthread(L); + lua_rawset(L, -3); + lua_pop(L, 1); +} + +static void luax_unpintask(lua_State* L, Task* task) { + lua_getfield(L, LUA_REGISTRYINDEX, "_lovrtasks"); + lua_pushlightuserdata(L, task); + lua_pushnil(L); + lua_rawset(L, -3); + lua_pop(L, 1); +} + +int luax_yieldpoll(lua_State* L, fn_task* poll, fn_task* block, fn_continuation* continuation, void* context) { + Task* task = luax_getthreaddata(L); + + if (!task) { + if (block(&context)) { + return continuation ? 
continuation(L, context) : 0; + } else { + lua_pushstring(L, lovrGetError()); + return lua_error(L); + } + } + + lovrTaskPoll(task, poll, block, continuation, context); + luax_pintask(L, task); + return lua_yield(L, 0); +} + +static void taskRunner(void* arg) { + Task* task = arg; + + if (!task->fn(&task->context)) { + task->error = lovrStrdup(lovrGetError()); + } + + lovrTaskEnqueue(task); + atomic_fetch_sub(&task->deps, 1); +} + +int luax_yieldjob(lua_State* L, fn_task* fn, fn_continuation* continuation, void* context) { + Task* task = luax_getthreaddata(L); + + if (!task) { + if (fn(&context)) { + return continuation ? continuation(L, context) : 0; + } else { + lua_pushstring(L, lovrGetError()); + return lua_error(L); + } + } + + task->fn = fn; + task->context = context; + task->continuation = continuation; + atomic_store(&task->deps, 1); + task->waiting = WAIT_JOB; + + if (!job_start(taskRunner, task)) { + task->waiting = WAIT_NONE; + atomic_store(&task->deps, 0); + if (fn(&context)) { + return continuation ? continuation(L, context) : 0; + } else { + lua_pushstring(L, lovrGetError()); + return lua_error(L); + } + } else { + luax_pintask(L, task); + return lua_yield(L, 0); + } +} + +static int luax_runtask(Task* task, int n) { + lua_State* T = task->T; + + if (task->waiting) { + // Unpin from registry + luax_unpintask(T, task); + + // Remove it from the ready queue if it's there + lovrTaskDequeue(task); + + // Handle error: can't actually throw an error in T without Lua 5.2 continuations + if (task->error) { + task->waiting = WAIT_NONE; + lua_settop(T, 0); + lua_pushstring(T, task->error); + lovrTaskFinish(task); + return LUA_ERRRUN; + } + + // Set up resume values + if (task->waiting == WAIT_JOB || task->waiting == WAIT_POLL) { + n = task->continuation ? 
task->continuation(T, task->context) : 0; + } else { + n = 0; + + // Copy the first result from each dependency, replacing each coroutine with its result + int top = lua_gettop(T); + for (int i = 1; i < top; i++) { + lua_State* D = lua_tothread(T, i); + lua_pushvalue(D, 1); + lua_xmove(D, T, 1); + lua_replace(T, i); + n++; + } + + // ...Except for the last dependency. Copy ALL of its results instead + lua_State* D = lua_tothread(T, top); + int rest = lua_gettop(D); + luax_check(T, lua_checkstack(D, rest), "stack overflow"); + for (int i = 1; i <= rest; i++) { + lua_pushvalue(D, i); + } + lua_pop(T, 1); + lua_xmove(D, T, rest); + n += rest; + } + + task->waiting = WAIT_NONE; + } + + int status = luax_resume(T, n); + + // Handle error/completion + if (status != LUA_YIELD) { + if (status != LUA_OK) { + task->error = lovrStrdup(lua_tostring(T, -1)); + } + + lovrTaskFinish(task); + } + + return status; +} + +static int l_lovrTaskResume(lua_State* L) { + luaL_checktype(L, 1, LUA_TTHREAD); + lua_State* T = lua_tothread(L, 1); + Task* task = luax_getthreaddata(T); + + if (!task) { + task = lovrTaskCreate(T); + luax_setthreaddata(T, task); + } else if (task->complete) { + lua_pushnil(L); + lua_pushliteral(L, "already complete"); + return 2; + } else if (!lovrTaskIsReady(task)) { + lua_pushnil(L); + lua_pushliteral(L, "not ready"); + return 2; + } + + int n = 0; + + // If the task wasn't waiting on anything (it yielded with coroutine.yield), give it arguments + if (!task->waiting) { + n = lua_gettop(L) - 1; + lua_xmove(L, T, n); + } + + int status = luax_runtask(task, n); + + if (task->waiting) { + lua_pushboolean(L, true); + return 1; + } + + luax_setthreaddata(T, NULL); + lovrTaskDestroy(task); + + if (status == LUA_OK) { + lua_pushboolean(L, true); + int n = lua_gettop(T); + luax_check(L, lua_checkstack(T, n), "stack overflow"); + for (int i = 1; i <= n; i++) { + lua_pushvalue(T, i); + } + lua_xmove(T, L, n); + return n + 1; + } else if (status == LUA_YIELD) { + 
lua_pushboolean(L, true); + // It yielded with coroutine.yield, return the results it yielded with + int n = lua_gettop(T); + luax_check(L, lua_checkstack(T, n), "stack overflow"); + for (int i = 1; i <= n; i++) { + lua_pushvalue(T, i); + } + lua_xmove(T, L, n); + return n + 1; + } else { + lua_pushboolean(L, false); + lua_pushvalue(T, -1); + lua_xmove(T, L, 1); + return 2; + } +} + +static int l_lovrTaskIsWaiting(lua_State* L) { + luaL_checktype(L, 1, LUA_TTHREAD); + lua_State* T = lua_tothread(L, 1); + Task* task = luax_getthreaddata(T); + lua_pushboolean(L, task && task->deps > 0); + return 1; +} + +static int l_lovrTaskNext(lua_State* L) { + Task* task = lovrTaskModuleGetNext(); + if (task) { + lua_getfield(L, LUA_REGISTRYINDEX, "_lovrtasks"); + lua_pushlightuserdata(L, task); + lua_rawget(L, -2); + } else { + lua_pushnil(L); + } + return 1; +} + +static int l_lovrTaskPoll(lua_State* L) { + lua_pushvalue(L, lua_upvalueindex(1)); + return 1; +} + +static int luax_waittask(lua_State* T) { + Task* task = luax_getthreaddata(T); + luax_check(T, task, "Trying to wait on a coroutine that wasn't resumed with lovr.task.resume"); + + if (task->complete) { + return task->error ? 
LUA_ERRRUN : LUA_OK; + } + + if (task->waiting == WAIT_JOB) { + while (atomic_load(&task->deps) > 0) { + job_spin(); + } + } else if (task->waiting == WAIT_POLL) { + if (!task->block(&task->context)) { + task->error = lovrStrdup(lovrGetError()); + } + } else { + int n = lua_gettop(T); + for (int i = 1; i <= n; i++) { + luax_waittask(lua_tothread(T, i)); + } + } + + return luax_runtask(task, 0); +} + +static int l_lovrTaskWait(lua_State* L) { + Task* self = luax_getthreaddata(L); + + if (lua_istable(L, 1)) { + int length = luax_len(L, 1); + + for (int i = 1; i <= length; i++) { + lua_rawgeti(L, 1, i); + } + + lua_remove(L, 1); + } + + int n = lua_gettop(L); + + if (n == 0) { + return 0; + } + + if (self) { + for (int i = 1; i <= n; i++) { + luaL_checktype(L, i, LUA_TTHREAD); + lua_State* T = lua_tothread(L, i); + Task* task = luax_getthreaddata(T); + luax_check(L, task, "Trying to wait on a coroutine that wasn't resumed with lovr.task.resume"); + luax_assert(L, lovrTaskAddDependency(self, task)); + } + + // Only yield if we're actually waiting on something. 
If everything was already complete, fall + // through to the synchronous path, which handles errors and gathers results + if (self->waiting == WAIT_TASK) { + luax_pintask(L, self); + return lua_yield(L, n); + } + } + + for (int i = 1; i <= n; i++) { + luaL_checktype(L, i, LUA_TTHREAD); + lua_State* T = lua_tothread(L, i); + + for (;;) { + int status = luax_waittask(T); + + if (status == LUA_OK) { + break; + } else if (status != LUA_YIELD) { + lua_pushboolean(L, false); + lua_pushvalue(T, -1); + lua_xmove(T, L, 1); + return 2; + } + } + } + + int results = 0; + + for (int i = 1; i <= n; i++) { + lua_State* T = lua_tothread(L, i); + + // Last task returns all args, other tasks return first arg + if (i < n) { + lua_pushvalue(T, 1); + lua_xmove(T, L, 1); + lua_replace(L, i); + results++; + } else { + int rest = lua_gettop(T); + for (int j = 1; j <= rest; j++) { + lua_pushvalue(T, j); + } + lua_pop(L, 1); + lua_xmove(T, L, rest); + results += rest; + } + } + + lua_pushboolean(L, true); + lua_insert(L, 1); + + return results + 1; +} + +extern const luaL_Reg lovrTask[]; + +static const luaL_Reg lovrTaskModule[] = { + { "resume", l_lovrTaskResume }, + { "isWaiting", l_lovrTaskIsWaiting }, + { "wait", l_lovrTaskWait }, + { NULL, NULL } +}; + +int luaopen_lovr_task(lua_State* L) { + lua_newtable(L); + luax_register(L, lovrTaskModule); + + lua_newtable(L); + lua_setfield(L, LUA_REGISTRYINDEX, "_lovrtasks"); + + lua_pushcfunction(L, l_lovrTaskNext); + lua_pushcclosure(L, l_lovrTaskPoll, 1); + lua_setfield(L, -2, "poll"); + + lovrTaskModuleInit(); + luax_atexit(L, lovrTaskModuleDestroy); + return 1; +} diff --git a/src/api/l_thread.c b/src/api/l_thread.c index 118508bc5..154ffb9ed 100644 --- a/src/api/l_thread.c +++ b/src/api/l_thread.c @@ -4,10 +4,11 @@ #include "core/os.h" #include "util.h" #include +#include #include #include -static char* threadRunner(Thread* thread, Blob* body, Variant* arguments, uint32_t argumentCount) { +static char* threadBody(Thread* thread, Blob* 
body, Variant* arguments, uint32_t argumentCount) { lua_State* L = luaL_newstate(); luaL_openlibs(L); luax_preload(L); @@ -55,7 +56,7 @@ static int l_lovrThreadNewThread(lua_State* L) { } else { lovrRetain(blob); } - Thread* thread = lovrThreadCreate(threadRunner, blob); + Thread* thread = lovrThreadCreate(threadBody, blob); luax_pushtype(L, Thread, thread); lovrRelease(thread, lovrThreadDestroy); lovrRelease(blob, lovrBlobDestroy); @@ -77,10 +78,197 @@ static int l_lovrThreadGetChannel(lua_State* L) { return 1; } +typedef struct RunContext { + struct RunContext* next; + arr_t(char) code; + lua_CFunction function; + uint32_t argumentCount; + uint32_t resultCount; + Variant* arguments; + Variant* results; + char* error; +} RunContext; + +static thread_local RunContext* contextPool; +static thread_local lua_State* workerState; + +static void onWorkerQuit(void) { + if (workerState) { + lua_close(workerState); + workerState = NULL; + } + + while (contextPool) { + RunContext* context = contextPool; + contextPool = context->next; + arr_free(&context->code); + lovrFree(context); + } +} + +static bool luax_runlua(void** arg) { + RunContext* context = *arg; + lua_State* L = workerState; + + if (!L) { + L = luaL_newstate(); + luaL_openlibs(L); + luax_preload(L); + workerState = L; + + lua_newtable(L); + lua_setfield(L, LUA_REGISTRYINDEX, "_lovrchunks"); + } + + int base = lua_gettop(L); + lua_pushcfunction(L, luax_getstack); + + if (context->function) { + lua_pushcfunction(L, context->function); + } else { + lua_getfield(L, LUA_REGISTRYINDEX, "_lovrchunks"); + lua_pushlstring(L, context->code.data, context->code.length); + lua_rawget(L, -2); + + if (lua_isfunction(L, -1)) { + lua_remove(L, -2); + } else { + if (luax_loadbufferx(L, context->code.data, context->code.length, "", "b")) { + for (uint32_t i = 0; i < context->argumentCount; i++) { + lovrVariantDestroy(&context->arguments[i]); + } + lovrSetError(lua_tostring(L, -1)); + lua_settop(L, base); + return false; + } + + 
lua_replace(L, -2); + lua_pushlstring(L, context->code.data, context->code.length); + lua_pushvalue(L, -2); + lua_rawset(L, -4); + lua_remove(L, -2); + } + } + + for (uint32_t i = 0; i < context->argumentCount; i++) { + luax_pushvariant(L, &context->arguments[i]); + lovrVariantDestroy(&context->arguments[i]); + } + + if (lua_pcall(L, context->argumentCount, LUA_MULTRET, base + 1) != LUA_OK) { + lovrSetError(lua_tostring(L, -1)); + lua_settop(L, base); + return false; + } + + int n = lua_gettop(L) - base - 1; + + if (n > 0) { + context->resultCount = n; + context->results = lovrRealloc(context->arguments, n * sizeof(Variant)); + context->argumentCount = 0; + context->arguments = NULL; + for (int i = 0; i < n; i++) { + luax_checkvariant(L, base + 2 + i, &context->results[i]); + } + } + + lua_settop(L, base); + + return true; +} + +static int luax_pushresults(lua_State* L, void* arg) { + RunContext* context = arg; + + if (context->error) { + lua_pushstring(L, context->error); + lovrFree(context->error); + lovrFree(context->arguments); + context->next = contextPool; + contextPool = context; + return lua_error(L); + } + + int n = context->resultCount; + + for (int i = 0; i < n; i++) { + luax_pushvariant(L, &context->results[i]); + lovrVariantDestroy(&context->results[i]); + } + + lovrFree(context->arguments); + lovrFree(context->results); + context->next = contextPool; + contextPool = context; + return n; +} + +static int writer(lua_State* L, const void* data, size_t size, void* userdata) { + RunContext* context = userdata; + arr_append(&context->code, data, size); + return 0; +} + +int luax_callthread(lua_State* L, int n) { + RunContext* context = contextPool; + + if (context) { + contextPool = context->next; + arr_clear(&context->code); + context->function = NULL; + context->argumentCount = context->resultCount = 0; + context->arguments = context->results = NULL; + context->error = NULL; + } else { + context = lovrCalloc(sizeof(RunContext)); + arr_init(&context->code); + } + + int function = lua_gettop(L) - n; 
+ + if (lua_iscfunction(L, function)) { + context->function = lua_tocfunction(L, function); + } else { + luaL_checktype(L, function, LUA_TFUNCTION); + lua_getfield(L, LUA_REGISTRYINDEX, "_lovrbytecode"); + lua_pushvalue(L, function); + lua_rawget(L, -2); + + if (lua_isnil(L, -1)) { + lua_pushvalue(L, function); + luax_check(L, !lua_dump(L, writer, context), "Failed to dump function to bytecode"); + lua_pushlstring(L, context->code.data, context->code.length); + lua_rawset(L, -4); + lua_pop(L, 2); + } else { + size_t length; + const char* code = lua_tolstring(L, -1, &length); + arr_append(&context->code, code, length); + lua_pop(L, 2); + } + } + + if (n > 0) { + context->argumentCount = n; + context->arguments = lovrMalloc(n * sizeof(Variant)); + for (int i = 0; i < n; i++) { + luax_checkvariant(L, i + 2, &context->arguments[i]); + } + } + + return luax_yieldjob(L, luax_runlua, luax_pushresults, context); +} + +static int l_lovrThreadRun(lua_State* L) { + return luax_callthread(L, lua_gettop(L) - 1); +} + static const luaL_Reg lovrThreadModule[] = { { "newThread", l_lovrThreadNewThread }, { "newChannel", l_lovrThreadNewChannel }, { "getChannel", l_lovrThreadGetChannel }, + { "run", l_lovrThreadRun }, { NULL, NULL } }; @@ -109,7 +297,10 @@ int luaopen_lovr_thread(lua_State* L) { } lua_pop(L, 1); - lovrThreadModuleInit(workers); + lua_newtable(L); + lua_setfield(L, LUA_REGISTRYINDEX, "_lovrbytecode"); + + lovrThreadModuleInit(workers, onWorkerQuit); luax_atexit(L, lovrThreadModuleDestroy); return 1; } diff --git a/src/api/l_timer.c b/src/api/l_timer.c index c30a93780..e92c3ce81 100644 --- a/src/api/l_timer.c +++ b/src/api/l_timer.c @@ -26,10 +26,21 @@ static int l_lovrTimerStep(lua_State* L) { return 1; } +static bool luax_polltime(void** context) { + double timeout = ((union { double f64; void* p; }) { .p = *context }).f64; + return lovrTimerGetTime() >= timeout; +} + +static bool luax_waittime(void** context) { + double timeout = ((union { double f64; void* p; }) { 
.p = *context }).f64; + lovrTimerSleep(timeout - lovrTimerGetTime()); + return true; +} + static int l_lovrTimerSleep(lua_State* L) { double duration = luaL_checknumber(L, 1); - lovrTimerSleep(duration); - return 0; + void* timeout = ((union { double f64; void* p; }) { .f64 = lovrTimerGetTime() + duration }).p; + return luax_yieldpoll(L, luax_polltime, luax_waittime, NULL, timeout); } static const luaL_Reg lovrTimer[] = { diff --git a/src/core/job.c b/src/core/job.c index cffc71d6a..6ab279112 100644 --- a/src/core/job.c +++ b/src/core/job.c @@ -16,9 +16,10 @@ static struct { atomic_uint head; atomic_uint tail; job jobs[MAX_JOBS]; - void (*setupWorker)(uint32_t id); thrd_t workers[MAX_WORKERS]; uint32_t workerCount; + fn_hook* workerInit; + fn_hook* workerQuit; cnd_t hasJob; mtx_t lock; bool quit; @@ -32,8 +33,10 @@ static void runJob(void) { } static int workerLoop(void* arg) { - if (state.setupWorker) { - state.setupWorker((uint32_t) (uintptr_t) arg); + uint32_t id = (uint32_t) (uintptr_t) arg; + + if (state.workerInit) { + state.workerInit(id); } for (;;) { @@ -51,14 +54,20 @@ static int workerLoop(void* arg) { } mtx_unlock(&state.lock); + + if (state.workerQuit) { + state.workerQuit(id); + } + return 0; } -bool job_init(uint32_t count, void (*setupWorker)(uint32_t id)) { +bool job_init(uint32_t count, fn_hook* init, fn_hook* quit) { mtx_init(&state.lock, mtx_plain); cnd_init(&state.hasJob); - state.setupWorker = setupWorker; + state.workerInit = init; + state.workerQuit = quit; if (count > MAX_WORKERS) count = MAX_WORKERS; for (uint32_t i = 0; i < count; i++, state.workerCount++) { if (thrd_create(&state.workers[i], workerLoop, (void*) (uintptr_t) i) != thrd_success) { diff --git a/src/core/job.h b/src/core/job.h index 1b4817a21..f241ec453 100644 --- a/src/core/job.h +++ b/src/core/job.h @@ -4,8 +4,9 @@ #pragma once typedef void fn_job(void* arg); +typedef void fn_hook(uint32_t worker); -bool job_init(uint32_t workerCount, void (*setupWorker)(uint32_t index)); 
+bool job_init(uint32_t workerCount, fn_hook* init, fn_hook* quit); void job_destroy(void); bool job_start(fn_job* fn, void* arg); void job_spin(void); diff --git a/src/modules/graphics/graphics.c b/src/modules/graphics/graphics.c index a2d89affb..f1892d67c 100644 --- a/src/modules/graphics/graphics.c +++ b/src/modules/graphics/graphics.c @@ -892,6 +892,7 @@ void lovrGraphicsDestroy(void) { } if (state.timestamps) gpu_tally_destroy(state.timestamps); lovrFree(state.timestamps); + if (state.window) lovrFree(state.window->sync); lovrRelease(state.window, lovrTextureDestroy); lovrRelease(state.windowPass, lovrPassDestroy); lovrRelease(state.defaultFont, lovrFontDestroy); @@ -2261,28 +2262,6 @@ const BufferInfo* lovrBufferGetInfo(Buffer* buffer) { return &buffer->info; } -void* lovrBufferGetData(Buffer* buffer, uint32_t offset, uint32_t extent) { - if (extent == ~0u) extent = buffer->info.size - offset; - lovrCheck(offset + extent <= buffer->info.size, "Buffer read range goes past the end of the Buffer"); - - mtx_lock(&state.lock); - - gpu_barrier barrier = syncStream(buffer->sync, GPU_PHASE_COPY, GPU_CACHE_TRANSFER_READ); - gpu_sync(state.stream, &barrier, 1); - - BufferView view = getBuffer(GPU_BUFFER_DOWNLOAD, extent, 4); - if (!view.buffer) return mtx_unlock(&state.lock), NULL; - - gpu_copy_buffers(state.stream, buffer->gpu, view.buffer, buffer->base + offset, view.offset, extent); - mtx_unlock(&state.lock); - - if (!lovrGraphicsSubmit(NULL, 0) || !lovrGraphicsWait()) { - return NULL; - } - - return view.pointer; -} - void* lovrBufferSetData(Buffer* buffer, uint32_t offset, uint32_t extent) { if (extent == ~0u) extent = buffer->info.size - offset; lovrCheck(offset + extent <= buffer->info.size, "Attempt to write past the end of the Buffer"); @@ -2820,41 +2799,6 @@ const TextureInfo* lovrTextureGetInfo(Texture* texture) { return &texture->info; } -Image* lovrTextureGetPixels(Texture* texture, uint32_t offset[4], uint32_t extent[3]) { - if (extent[0] == ~0u) 
extent[0] = texture->info.width - offset[0]; - if (extent[1] == ~0u) extent[1] = texture->info.height - offset[1]; - lovrCheck(extent[2] == 1, "Currently only a single layer can be read from a Texture"); - lovrCheck(texture->info.usage & TEXTURE_TRANSFER, "Texture must be created with the 'transfer' usage to read from it"); - lovrCheck(texture->info.samples == 1, "Can't get pixels of a multisampled texture"); - if (!checkTextureBounds(&texture->info, offset, extent)) return NULL; - - mtx_lock(&state.lock); - - gpu_barrier barrier = syncStream(texture->sync, GPU_PHASE_COPY, GPU_CACHE_TRANSFER_READ); - gpu_sync(state.stream, &barrier, 1); - - uint32_t rootOffset[4] = { offset[0], offset[1], offset[2] + texture->baseLayer, offset[3] + texture->baseLevel }; - - BufferView view = getBuffer(GPU_BUFFER_DOWNLOAD, measureTexture(texture->info.format, extent[0], extent[1], 1), 64); - if (!view.buffer) return mtx_unlock(&state.lock), NULL; - - gpu_copy_texture_buffer(state.stream, texture->root->gpu, view.buffer, rootOffset, view.offset, extent); - mtx_unlock(&state.lock); - - if (!lovrGraphicsSubmit(NULL, 0) || !lovrGraphicsWait()) { - return NULL; - } - - Image* image = lovrImageCreateRaw(extent[0], extent[1], texture->info.format, texture->info.srgb); - - if (image) { - void* data = lovrImageGetLayerData(image, offset[3], offset[2]); - memcpy(data, view.pointer, view.extent); - } - - return image; -} - bool lovrTextureSetPixels(Texture* texture, Image* image, uint32_t dstOffset[4], uint32_t srcOffset[4], uint32_t extent[3]) { TextureFormat format = texture->info.format; if (extent[0] == ~0u) extent[0] = MIN(texture->info.width - dstOffset[0], lovrImageGetWidth(image, srcOffset[3]) - srcOffset[0]); @@ -4772,6 +4716,10 @@ void lovrMeshDestroy(void* ref) { lovrFree(mesh); } +MeshStorage lovrMeshGetStorage(Mesh* mesh) { + return mesh->storage; +} + const DataField* lovrMeshGetVertexFormat(Mesh* mesh) { return mesh->vertexBuffer->info.format; } @@ -4814,12 +4762,8 @@ void* 
lovrMeshGetVertices(Mesh* mesh, uint32_t index, uint32_t count) { const DataField* format = lovrMeshGetVertexFormat(mesh); if (count == ~0u) count = format->length - index; lovrCheck(index < format->length && count <= format->length - index, "Mesh vertex range [%d,%d] overflows mesh capacity", index + 1, index + 1 + count - 1); - - if (mesh->storage == MESH_CPU) { - return (char*) mesh->vertices + index * format->stride; - } else { - return lovrBufferGetData(mesh->vertexBuffer, index * format->stride, count * format->stride); - } + lovrCheck(mesh->storage == MESH_CPU, "Can't get vertices of GPU mesh"); + return (char*) mesh->vertices + index * format->stride; } void* lovrMeshSetVertices(Mesh* mesh, uint32_t index, uint32_t count) { @@ -4837,21 +4781,17 @@ void* lovrMeshSetVertices(Mesh* mesh, uint32_t index, uint32_t count) { } bool lovrMeshGetIndices(Mesh* mesh, void** indices, uint32_t* count, DataType* type) { - if (mesh->indexCount == 0 || !mesh->indexBuffer) { + lovrCheck(mesh->storage == MESH_CPU, "Can't get indices of GPU mesh"); + + if (mesh->indexCount == 0 || !mesh->indexBuffer || mesh->storage == MESH_GPU) { *indices = NULL; return true; } *count = mesh->indexCount; *type = mesh->indexBuffer->info.format->type; - - if (mesh->storage == MESH_CPU) { - *indices = mesh->indices; - return true; - } else { - *indices = lovrBufferGetData(mesh->indexBuffer, 0, mesh->indexCount * mesh->indexBuffer->info.format->stride); - return *indices != NULL; - } + *indices = mesh->indices; + return true; } void* lovrMeshSetIndices(Mesh* mesh, uint32_t count, DataType type) { @@ -6185,25 +6125,21 @@ Readback* lovrReadbackCreateTexture(Texture* texture, uint32_t offset[4], uint32 if (extent[0] == ~0u) extent[0] = texture->info.width - offset[0]; if (extent[1] == ~0u) extent[1] = texture->info.height - offset[1]; lovrCheck(extent[2] == 1, "Currently, only one layer can be read from a Texture"); - lovrCheck(texture->root == texture, "Can not read from a Texture view"); -
lovrCheck(texture->info.samples == 1, "Can not read from a multisampled texture"); + lovrCheck(texture->info.samples == 1, "Can't get pixels of a multisampled texture"); lovrCheck(texture->info.usage & TEXTURE_TRANSFER, "Texture must be created with the 'transfer' usage to read from it"); - checkTextureBounds(&texture->info, offset, extent); + if (!checkTextureBounds(&texture->info, offset, extent)) return NULL; + Image* image = lovrImageCreateRaw(extent[0], extent[1], texture->info.format, texture->info.srgb); + lovrAssert(image, "Failed to create image: %s", lovrGetError()); mtx_lock(&state.lock); BufferView view = getBuffer(GPU_BUFFER_DOWNLOAD, measureTexture(texture->info.format, extent[0], extent[1], 1), 64); if (!view.buffer) return mtx_unlock(&state.lock), NULL; - Image* image = lovrImageCreateRaw(extent[0], extent[1], texture->info.format, texture->info.srgb); - if (!image) { - mtx_unlock(&state.lock); - lovrSetError("Failed to create image: %s", lovrGetError()); - return NULL; - } Readback* readback = lovrReadbackCreate(READBACK_TEXTURE); readback->image = image; readback->view = view; gpu_barrier barrier = syncStream(texture->sync, GPU_PHASE_COPY, GPU_CACHE_TRANSFER_READ); gpu_sync(state.stream, &barrier, 1); - gpu_copy_texture_buffer(state.stream, texture->gpu, readback->view.buffer, offset, readback->view.offset, extent); + uint32_t rootOffset[4] = { offset[0], offset[1], offset[2] + texture->baseLayer, offset[3] + texture->baseLevel }; + gpu_copy_texture_buffer(state.stream, texture->root->gpu, readback->view.buffer, rootOffset, readback->view.offset, extent); mtx_unlock(&state.lock); return readback; } diff --git a/src/modules/graphics/graphics.h b/src/modules/graphics/graphics.h index 0017dd706..f8a318a91 100644 --- a/src/modules/graphics/graphics.h +++ b/src/modules/graphics/graphics.h @@ -196,7 +196,6 @@ typedef struct { Buffer* lovrBufferCreate(const BufferInfo* info, void** data); void lovrBufferDestroy(void* ref); const BufferInfo*
lovrBufferGetInfo(Buffer* buffer); -void* lovrBufferGetData(Buffer* buffer, uint32_t offset, uint32_t extent); void* lovrBufferSetData(Buffer* buffer, uint32_t offset, uint32_t extent); bool lovrBufferCopy(Buffer* src, Buffer* dst, uint32_t srcOffset, uint32_t dstOffset, uint32_t extent); bool lovrBufferClear(Buffer* buffer, uint32_t offset, uint32_t extent, uint32_t value); @@ -255,7 +254,6 @@ Texture* lovrTextureCreate(const TextureInfo* info); Texture* lovrTextureCreateView(Texture* parent, const TextureViewInfo* info); void lovrTextureDestroy(void* ref); const TextureInfo* lovrTextureGetInfo(Texture* texture); -struct Image* lovrTextureGetPixels(Texture* texture, uint32_t offset[4], uint32_t extent[3]); bool lovrTextureSetPixels(Texture* texture, struct Image* image, uint32_t texOffset[4], uint32_t imgOffset[4], uint32_t extent[3]); bool lovrTextureCopy(Texture* src, Texture* dst, uint32_t srcOffset[4], uint32_t dstOffset[4], uint32_t extent[3]); bool lovrTextureBlit(Texture* src, Texture* dst, uint32_t srcOffset[4], uint32_t dstOffset[4], uint32_t srcExtent[3], uint32_t dstExtent[3], FilterMode filter); @@ -468,6 +466,7 @@ typedef struct { Mesh* lovrMeshCreate(const MeshInfo* info, void** data); void lovrMeshDestroy(void* ref); +MeshStorage lovrMeshGetStorage(Mesh* mesh); const DataField* lovrMeshGetVertexFormat(Mesh* mesh); Buffer* lovrMeshGetVertexBuffer(Mesh* mesh); Buffer* lovrMeshGetIndexBuffer(Mesh* mesh); diff --git a/src/modules/task/task.c b/src/modules/task/task.c new file mode 100644 index 000000000..ed0575375 --- /dev/null +++ b/src/modules/task/task.c @@ -0,0 +1,176 @@ +#include "task/task.h" +#include "core/job.h" +#include "util.h" +#include <stdatomic.h> +#include <string.h> + +static atomic_uint ref; + +static struct { + Waiter* waiters; + _Atomic(Task*) pending; + Task* queue; + Task* polls; + Task* pool; +} state; + +bool lovrTaskModuleInit(void) { + if (!lovrModuleAcquire(&ref)) return false; + lovrModuleReady(&ref); + return true; +} + +void
lovrTaskModuleDestroy(void) { + if (!lovrModuleRelease(&ref)) return; + while (state.pool) { + Task* task = state.pool; + state.pool = task->next; + lovrFree(task); + } + while (state.waiters) { + Waiter* waiter = state.waiters; + state.waiters = waiter->next; + lovrFree(waiter); + } + memset(&state, 0, sizeof(state)); + lovrModuleReset(&ref); +} + +Task* lovrTaskModuleGetNext(void) { + for (;;) { + while (state.queue) { + Task* task = state.queue; + state.queue = task->next; + if (!task->dequeued) { + return task; + } + } + + Task* task = atomic_exchange(&state.pending, NULL); + + if (!task) { + break; + } + + while (task) { + Task* next = task->next; + task->next = state.queue; + state.queue = task; + task = next; + } + } + + Task** list = &state.polls; + while (*list) { + Task* task = *list; + if (task->fn(&task->context)) { + *list = task->next; + if (atomic_fetch_sub(&task->deps, 1) == 1) { + return task; + } + } else { + list = &task->next; + } + } + + return NULL; +} + +// Task + +Task* lovrTaskCreate(struct lua_State* T) { + Task* task = state.pool; + + if (task) { + state.pool = task->next; + memset(task, 0, sizeof(Task)); + } else { + task = lovrCalloc(sizeof(Task)); + } + + task->T = T; + return task; +} + +void lovrTaskDestroy(Task* task) { + while (task->waiters) { + Waiter* waiter = task->waiters; + task->waiters = waiter->next; + waiter->next = state.waiters; + state.waiters = waiter; + } + lovrFree(task->error); + task->next = state.pool; + state.pool = task; +} + +bool lovrTaskIsReady(Task* task) { + return !task->complete && task->deps == 0; +} + +void lovrTaskEnqueue(Task* task) { + task->dequeued = false; + task->next = atomic_load(&state.pending); + while (!atomic_compare_exchange_strong(&state.pending, &task->next, task)); +} + +void lovrTaskDequeue(Task* task) { + task->dequeued = true; +} + +void lovrTaskPoll(Task* task, fn_task* poll, fn_task* block, fn_continuation* continuation, void* context) { + task->fn = poll; + task->block = block; + 
task->context = context; + task->continuation = continuation; + atomic_store(&task->deps, 1); + task->waiting = WAIT_POLL; + task->next = state.polls; + state.polls = task; +} + +void lovrTaskFinish(Task* task) { + while (task->waiters) { + Waiter* waiter = task->waiters; + task->waiters = waiter->next; + + // If this task failed, copy the error to any dependents + if (task->error && !waiter->task->error) { + waiter->task->error = lovrStrdup(task->error); + } + + if (atomic_fetch_sub(&waiter->task->deps, 1) == 1) { + lovrTaskEnqueue(waiter->task); + } + + waiter->next = state.waiters; + state.waiters = waiter; + } + + task->complete = true; +} + +bool lovrTaskAddDependency(Task* task, Task* dep) { + if (dep->complete) { + if (dep->error && !task->error) { + task->error = lovrStrdup(dep->error); + } + return true; + } + + lovrAssert(task->deps < ~0u, "Task is waiting on too many other tasks"); + Waiter* waiter = state.waiters; + + if (waiter) { + state.waiters = waiter->next; + } else { + waiter = lovrMalloc(sizeof(Waiter)); + } + + waiter->next = dep->waiters; + waiter->task = task; + dep->waiters = waiter; + task->waiting = WAIT_TASK; + atomic_fetch_add(&task->deps, 1); + return true; +} diff --git a/src/modules/task/task.h b/src/modules/task/task.h new file mode 100644 index 000000000..ac778ced4 --- /dev/null +++ b/src/modules/task/task.h @@ -0,0 +1,53 @@ +#include <stdatomic.h> +#include <stdbool.h> +#include <stdint.h> + +#pragma once + +struct lua_State; +typedef bool fn_task(void** context); +typedef int fn_continuation(struct lua_State* L, void* context); + +typedef struct Task Task; + +bool lovrTaskModuleInit(void); +void lovrTaskModuleDestroy(void); +Task* lovrTaskModuleGetNext(void); + +// Task + +typedef struct Waiter { + struct Waiter* next; + Task* task; +} Waiter; + +typedef enum { + WAIT_NONE, + WAIT_TASK, + WAIT_POLL, + WAIT_JOB +} WaitType; + +struct Task { + bool complete; + bool dequeued; + WaitType waiting; + atomic_uint deps; + struct Task* next; + struct lua_State* T; + fn_task*
fn; + fn_task* block; + fn_continuation* continuation; + void* context; + Waiter* waiters; + char* error; +}; + +Task* lovrTaskCreate(struct lua_State* T); +void lovrTaskDestroy(Task* task); +bool lovrTaskIsReady(Task* task); +void lovrTaskEnqueue(Task* task); +void lovrTaskDequeue(Task* task); +void lovrTaskPoll(Task* task, fn_task* poll, fn_task* block, fn_continuation* continuation, void* context); +void lovrTaskFinish(Task* task); +bool lovrTaskAddDependency(Task* task, Task* dependency); diff --git a/src/modules/thread/thread.c b/src/modules/thread/thread.c index d14223d7e..8571262d2 100644 --- a/src/modules/thread/thread.c +++ b/src/modules/thread/thread.c @@ -39,24 +39,33 @@ static atomic_uint ref; static struct { uint32_t workers; + void (*onWorkerQuit)(void); mtx_t channelLock; map_t channels; } state; -static void setupWorker(uint32_t id) { +static void workerInit(uint32_t id) { lovrProfileSetThreadName("Worker"); os_thread_set_name("Worker"); } -bool lovrThreadModuleInit(int32_t workers) { +static void workerQuit(uint32_t id) { + if (state.onWorkerQuit) { + state.onWorkerQuit(); + } +} + +bool lovrThreadModuleInit(int32_t workers, void (*onWorkerQuit)(void)) { if (!lovrModuleAcquire(&ref)) return true; + mtx_init(&state.channelLock, mtx_plain); map_init(&state.channels, 0); uint32_t cores = os_get_core_count(); if (workers < 0) workers += cores; state.workers = MAX(workers, 0); - job_init(state.workers, setupWorker); + state.onWorkerQuit = onWorkerQuit; + job_init(state.workers, workerInit, workerQuit); lovrModuleReady(&ref); return true; diff --git a/src/modules/thread/thread.h b/src/modules/thread/thread.h index 3f5f7143b..8e2dcf5c9 100644 --- a/src/modules/thread/thread.h +++ b/src/modules/thread/thread.h @@ -14,7 +14,7 @@ union Variant; typedef struct Thread Thread; typedef struct Channel Channel; -bool lovrThreadModuleInit(int32_t workers); +bool lovrThreadModuleInit(int32_t workers, void (*onWorkerQuit)(void)); void lovrThreadModuleDestroy(void); 
uint32_t lovrThreadGetWorkerCount(void); struct Channel* lovrThreadGetChannel(const char* name); diff --git a/src/util.h b/src/util.h index dd5b9b1d6..108955a2d 100644 --- a/src/util.h +++ b/src/util.h @@ -95,7 +95,7 @@ void lovrLog(int level, const char* tag, const char* format, ...); #define arr_expand(a, n) arr_reserve(a, (a)->length + n) #define arr_push(a, x) arr_reserve(a, (a)->length + 1), (a)->data[(a)->length] = x, (a)->length++ #define arr_pop(a) (a)->data[--(a)->length] -#define arr_append(a, p, n) arr_reserve(a, (a)->length + n), memcpy((a)->data + (a)->length, p, n * sizeof(*(p))), (a)->length += n +#define arr_append(a, p, n) arr_reserve(a, (a)->length + n), memcpy((a)->data + (a)->length, p, n * sizeof(*(a)->data)), (a)->length += n #define arr_splice(a, i, n) memmove((a)->data + (i), (a)->data + ((i) + n), ((a)->length - (i) - (n)) * sizeof(*(a)->data)), (a)->length -= n #define arr_clear(a) (a)->length = 0