diff --git a/symforce/caspar/source/runtime/pybind_array_tools.cc b/symforce/caspar/source/runtime/pybind_array_tools.cc index ec032103..9e43f198 100644 --- a/symforce/caspar/source/runtime/pybind_array_tools.cc +++ b/symforce/caspar/source/runtime/pybind_array_tools.cc @@ -169,6 +169,23 @@ void AssertUint2Vec(const py::object& obj) { Assert2DNxk(obj, 2); } +int GetDeviceId(const py::object& obj) { + try { + auto interface = obj.attr("__cuda_array_interface__").cast(); + auto data = interface["data"].cast(); + void* ptr = reinterpret_cast(data[0].cast()); + cudaPointerAttributes attrs; + cudaError_t err = cudaPointerGetAttributes(&attrs, ptr); + if (err != cudaSuccess) { + cudaGetLastError(); + return -1; + } + return attrs.device; + } catch (...) { + return -1; // Fallback if interface or attributes aren't available + } +} + float* AsFloatPtr(const py::object& obj) { AssertFloatVec(obj); py::tuple data = GetInterface(obj)["data"].cast(); diff --git a/symforce/caspar/source/runtime/pybind_array_tools.h b/symforce/caspar/source/runtime/pybind_array_tools.h index c6f77fc5..96445288 100644 --- a/symforce/caspar/source/runtime/pybind_array_tools.h +++ b/symforce/caspar/source/runtime/pybind_array_tools.h @@ -30,6 +30,8 @@ void AssertDeviceMemory(const py::object& obj); void AssertNumRowsEquals(const py::object& obj, size_t n); void AssertNumColsEquals(const py::object& obj, size_t n); +int GetDeviceId(const py::object& obj); + float* AsFloatPtr(const py::object& obj); double* AsDoublePtr(const py::object& obj); int* AsIntPtr(const py::object& obj); diff --git a/symforce/caspar/source/templates/caspar_mappings_pybinding.h.jinja b/symforce/caspar/source/templates/caspar_mappings_pybinding.h.jinja index 4c1e7b11..05f250c9 100644 --- a/symforce/caspar/source/templates/caspar_mappings_pybinding.h.jinja +++ b/symforce/caspar/source/templates/caspar_mappings_pybinding.h.jinja @@ -27,6 +27,7 @@ void add_casmappings_pybindings(pybind11::module_ module) { throw std::runtime_error( "The caspar data must have at least as many columns as stacked_data has rows."); } + cudaSetDevice(GetDeviceId(stacked_data)); {{nodetype.__name__}}StackedToCaspar( As{{caslib.storage_t.capitalize()}}Ptr(stacked_data), As{{caslib.storage_t.capitalize()}}Ptr(cas_data), cas_stride, 0, num_objects); }); @@ -45,7 +46,7 @@ void add_casmappings_pybindings(pybind11::module_ module) { throw std::runtime_error( "The caspar data must have at least as many columns as stacked_data has rows."); } - + cudaSetDevice(GetDeviceId(cas_data)); {{nodetype.__name__}}CasparToStacked( As{{caslib.storage_t.capitalize()}}Ptr(cas_data), As{{caslib.storage_t.capitalize()}}Ptr(stacked_data), cas_stride, 0, num_objects); }); diff --git a/symforce/caspar/source/templates/lib.pyi.jinja b/symforce/caspar/source/templates/lib.pyi.jinja index 2d75b55a..a208c4de 100644 --- a/symforce/caspar/source/templates/lib.pyi.jinja +++ b/symforce/caspar/source/templates/lib.pyi.jinja @@ -68,6 +68,7 @@ class {{solver.struct_name}}: {% for thing in solver.size_contributors %} {{num_arg_key(thing)}}: int = 0, {% endfor %} + device_id: int = 0, ): ... def set_params(self, params: SolverParams) -> None: diff --git a/symforce/caspar/source/templates/solver.cc.jinja b/symforce/caspar/source/templates/solver.cc.jinja index 0804e72f..4379d6e8 100644 --- a/symforce/caspar/source/templates/solver.cc.jinja +++ b/symforce/caspar/source/templates/solver.cc.jinja @@ -68,10 +68,12 @@ namespace caspar { {{ solver.struct_name }}::{{ solver.struct_name }}( const SolverParams ¶ms, {% for thing in solver.size_contributors %} - size_t {{num_arg_key(thing)}}{{ ", " if not loop.last else "" }} + size_t {{num_arg_key(thing)}}{{ ", " }} {% endfor %} + int device_id ) : params_(params), + device_id_(device_id), {% for thing in solver.size_contributors %} {{num_key(thing)}}({{num_arg_key(thing)}}), {{num_max_key(thing)}}({{num_arg_key(thing)}}){{ ", " if not loop.last else "" }} @@ -85,6 +87,20 @@ namespace caspar { throw std::runtime_error("params.diag_init must be positive"); } allocation_size_ = get_nbytes(); + + if (device_id_ < 0) { + throw std::runtime_error("Invalid CUDA device id: " + std::to_string(device_id_)); + } + if (device_id_ != 0) { + int deviceCount; + cudaGetDeviceCount(&deviceCount); + if (deviceCount <= device_id_) { + throw std::runtime_error("CUDA detected " + std::to_string(deviceCount) + + " devices, but device " + std::to_string(device_id_) + + " was requested (0-indexed)"); + } + } + cudaSetDevice(device_id_); cudaMalloc(&origin_ptr_, allocation_size_); size_t offset = 0; @@ -97,6 +113,7 @@ namespace caspar { } {{ solver.struct_name }}::~{{ solver.struct_name }}(){ + cudaSetDevice(device_id_); cudaFree(origin_ptr_); } @@ -110,6 +127,7 @@ size_t {{ solver.struct_name }}::get_allocation_size(){ SolveResult {{ solver.struct_name }}::solve(bool print_progress, bool verbose_logging) { + cudaSetDevice(device_id_); SolveResult result; result.exit_reason = ExitReason::MAX_ITERATIONS; {{solver.linear_t}} score_best; @@ -123,6 +141,7 @@ SolveResult {{ solver.struct_name }}::solve(bool print_progress, bool verbose_lo std::chrono::time_point t0 = std::chrono::steady_clock::now(); std::chrono::time_point t_prev = t0; score_best = DoResJacFirst(); + result.initial_score = score_best; if (print_progress) { printf(" score_init: % .6e\n", score_best); } @@ -634,6 +653,7 @@ void {{ solver.struct_name }}::finish_indices() { {% for nodetype in solver.node_types %} void {{ solver.struct_name }}::Set{{nodetype.__name__}}Num(const size_t num) { + cudaSetDevice(device_id_); if (num > {{num_max_key(nodetype)}}) { throw std::runtime_error(std::to_string(num) + " > {{num_max_key(nodetype)}}"); } @@ -642,6 +662,7 @@ void {{ solver.struct_name }}::Set{{nodetype.__name__}}Num(const size_t num) { void {{ solver.struct_name }}::Set{{nodetype.__name__}}NodesFromStackedHost( const {{solver.storage_t}}* const data, const size_t offset, const size_t num) { + cudaSetDevice(device_id_); if (offset + num > {{num_key(nodetype)}}){ throw std::runtime_error(std::to_string(offset + num) + " > {{num_key(nodetype)}}"); } @@ -654,6 +675,7 @@ void {{ solver.struct_name }}::Set{{nodetype.__name__}}NodesFromStackedHost( void {{ solver.struct_name }}::Set{{nodetype.__name__}}NodesFromStackedDevice( const {{solver.storage_t}}* const data, const size_t offset, const size_t num) { + cudaSetDevice(device_id_); if (offset + num > {{num_key(nodetype)}}){ throw std::runtime_error(std::to_string(offset + num) + " > {{num_key(nodetype)}}"); } @@ -663,6 +685,7 @@ void {{ solver.struct_name }}::Set{{nodetype.__name__}}NodesFromStackedDevice( void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedHost( {{solver.storage_t}}* const data, const size_t offset, const size_t num) { + cudaSetDevice(device_id_); if (offset + num > {{num_key(nodetype)}}){ throw std::runtime_error(std::to_string(offset + num) + " > {{num_key(nodetype)}}"); } @@ -675,6 +698,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedHost( void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice( {{solver.storage_t}}* const data, const size_t offset, const size_t num) { + cudaSetDevice(device_id_); if (offset + num > {{num_key(nodetype)}}){ throw std::runtime_error(std::to_string(offset + num) + " > {{num_key(nodetype)}}"); } @@ -695,6 +719,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice( {% if fac.isnodeshared[arg] %} void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromHost( const unsigned int* const indices, size_t num) { + cudaSetDevice(device_id_); if (num != {{num_key(fac)}}){ throw std::runtime_error( std::to_string(num) @@ -708,7 +733,8 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice( void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromDevice( const unsigned int* const indices, size_t num) { indices_valid_ = false; - + cudaSetDevice(device_id_); + if (num != {{num_key(fac)}}){ throw std::runtime_error( std::to_string(num) @@ -737,6 +763,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice( const {{solver.storage_t}}* const data, size_t offset, size_t num {% endif %} ) { + cudaSetDevice(device_id_); {% if fac.isconstuniq[arg] %} const size_t offset = 0; const size_t num = 1; @@ -769,6 +796,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice( {% elif fac.isconstindexed[arg] %} const {{solver.storage_t}}* const data, size_t offset, size_t num {% endif %} ) { + cudaSetDevice(device_id_); {% if fac.isconstuniq[arg] %} const size_t offset = 0; const size_t num = 1; @@ -791,6 +819,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice( {% if fac.isconstshared[arg] %} void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromHost( const unsigned int* const indices, size_t num) { + cudaSetDevice(device_id_); if (num != {{num_key(fac)}}){ throw std::runtime_error( std::to_string(num) @@ -804,7 +833,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice( void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromDevice( const unsigned int* const indices, size_t num) { indices_valid_ = false; - + cudaSetDevice(device_id_); if (num != {{num_key(fac)}}){ throw std::runtime_error( std::to_string(num) @@ -824,6 +853,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice( void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromHost( const unsigned int* const indices, size_t num) { indices_valid_ = false; + cudaSetDevice(device_id_); if (num != {{num_key(fac)}}){ throw std::runtime_error( std::to_string(num) @@ -836,6 +866,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice( void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromDevice( const unsigned int* const indices, size_t num) { indices_valid_ = false; + cudaSetDevice(device_id_); if (num != {{num_key(fac)}}){ throw std::runtime_error( std::to_string(num) diff --git a/symforce/caspar/source/templates/solver.h.jinja b/symforce/caspar/source/templates/solver.h.jinja index 3b4633b4..855a9789 100644 --- a/symforce/caspar/source/templates/solver.h.jinja +++ b/symforce/caspar/source/templates/solver.h.jinja @@ -55,8 +55,9 @@ class {{ solver.struct_name }} { {{ solver.struct_name }}( const SolverParams ¶ms, {% for thing in solver.size_contributors %} - size_t {{num_arg_key(thing)}}{{ ", " if not loop.last else "" }} + size_t {{num_arg_key(thing)}}{{ ", " }} {% endfor %} + int device_id = 0 ); // This class is managing cuda memory and cannot be copied. @@ -210,6 +211,7 @@ class {{ solver.struct_name }} { private: SolverParams<{{solver.linear_t}}> params_; + int device_id_; uint8_t* origin_ptr_; size_t scratch_inout_size_; size_t allocation_size_; diff --git a/symforce/caspar/source/templates/solver_pybinding.h.jinja b/symforce/caspar/source/templates/solver_pybinding.h.jinja index 32b1477f..25c840c1 100644 --- a/symforce/caspar/source/templates/solver_pybinding.h.jinja +++ b/symforce/caspar/source/templates/solver_pybinding.h.jinja @@ -41,12 +41,14 @@ inline void add_solver_pybinding(pybind11::module_ module) { .def(py::init, {% for thing in solver.size_contributors %} size_t{{ ", " if not loop.last else "" }} - {% endfor %}>(), + {% endfor %}, + int>(), py::arg("params"), py::kw_only(), {% for thing in solver.size_contributors %} - py::arg("{{num_arg_key(thing)}}") = 0{{ ", " if not loop.last else "" }} + py::arg("{{num_arg_key(thing)}}") = 0{{ ", " }} {% endfor %} + py::arg("device_id") = 0 ) .def("set_params", &{{solver.struct_name}}::set_params)