Skip to content
17 changes: 17 additions & 0 deletions symforce/caspar/source/runtime/pybind_array_tools.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,23 @@ void AssertUint2Vec(const py::object& obj) {
Assert2DNxk(obj, 2);
}

// Determines which CUDA device the memory behind a Python array object is
// associated with.
//
// The base address is taken from the first element of the "data" tuple in
// obj's `__cuda_array_interface__` dict and handed to
// cudaPointerGetAttributes.
//
// @param obj  Python object expected to expose `__cuda_array_interface__`.
// @return     The `device` field of the queried pointer attributes, or -1 when
//             the interface cannot be read/cast or the CUDA query fails.
int GetDeviceId(const py::object& obj) {
  constexpr int kUnknownDevice = -1;
  try {
    py::dict cuda_iface = obj.attr("__cuda_array_interface__").cast<py::dict>();
    py::tuple data_entry = cuda_iface["data"].cast<py::tuple>();
    const size_t address = data_entry[0].cast<size_t>();

    cudaPointerAttributes attributes;
    const cudaError_t status =
        cudaPointerGetAttributes(&attributes, reinterpret_cast<void*>(address));
    if (status == cudaSuccess) {
      return attributes.device;
    }

    // The return value is intentionally discarded; per the CUDA runtime docs
    // this call resets the recorded last-error state after the failed query.
    cudaGetLastError();
    return kUnknownDevice;
  } catch (...) {
    // Fallback if interface or attributes aren't available
    return kUnknownDevice;
  }
}

float* AsFloatPtr(const py::object& obj) {
AssertFloatVec(obj);
py::tuple data = GetInterface(obj)["data"].cast<py::tuple>();
Expand Down
2 changes: 2 additions & 0 deletions symforce/caspar/source/runtime/pybind_array_tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ void AssertDeviceMemory(const py::object& obj);
void AssertNumRowsEquals(const py::object& obj, size_t n);
void AssertNumColsEquals(const py::object& obj, size_t n);

// Returns the CUDA device id reported by cudaPointerGetAttributes for the data
// pointer in obj's `__cuda_array_interface__`, or -1 if the interface cannot be
// read or the attribute query fails.
int GetDeviceId(const py::object& obj);

float* AsFloatPtr(const py::object& obj);
double* AsDoublePtr(const py::object& obj);
int* AsIntPtr(const py::object& obj);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ void add_casmappings_pybindings(pybind11::module_ module) {
throw std::runtime_error(
"The caspar data must have at least as many columns as stacked_data has rows.");
}
cudaSetDevice(GetDeviceId(stacked_data));
{{nodetype.__name__}}StackedToCaspar(
As{{caslib.storage_t.capitalize()}}Ptr(stacked_data), As{{caslib.storage_t.capitalize()}}Ptr(cas_data), cas_stride, 0, num_objects);
});
Expand All @@ -45,7 +46,7 @@ void add_casmappings_pybindings(pybind11::module_ module) {
throw std::runtime_error(
"The caspar data must have at least as many columns as stacked_data has rows.");
}

cudaSetDevice(GetDeviceId(cas_data));
{{nodetype.__name__}}CasparToStacked(
As{{caslib.storage_t.capitalize()}}Ptr(cas_data), As{{caslib.storage_t.capitalize()}}Ptr(stacked_data), cas_stride, 0, num_objects);
});
Expand Down
1 change: 1 addition & 0 deletions symforce/caspar/source/templates/lib.pyi.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class {{solver.struct_name}}:
{% for thing in solver.size_contributors %}
{{num_arg_key(thing)}}: int = 0,
{% endfor %}
device_id: int = 0,
): ...

def set_params(self, params: SolverParams) -> None:
Expand Down
37 changes: 34 additions & 3 deletions symforce/caspar/source/templates/solver.cc.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,12 @@ namespace caspar {
{{ solver.struct_name }}::{{ solver.struct_name }}(
const SolverParams<double> &params,
{% for thing in solver.size_contributors %}
size_t {{num_arg_key(thing)}}{{ ", " if not loop.last else "" }}
size_t {{num_arg_key(thing)}}{{ ", " }}
{% endfor %}
int device_id
)
: params_(params),
device_id_(device_id),
{% for thing in solver.size_contributors %}
{{num_key(thing)}}({{num_arg_key(thing)}}),
{{num_max_key(thing)}}({{num_arg_key(thing)}}){{ ", " if not loop.last else "" }}
Expand All @@ -85,6 +87,20 @@ namespace caspar {
throw std::runtime_error("params.diag_init must be positive");
}
allocation_size_ = get_nbytes();

if (device_id_ < 0) {
throw std::runtime_error("Invalid CUDA device id: " + std::to_string(device_id_));
}
if (device_id_ != 0) {
int deviceCount;
cudaGetDeviceCount(&deviceCount);
if (deviceCount <= device_id_) {
throw std::runtime_error("CUDA detected " + std::to_string(deviceCount) +
" devices, but device " + std::to_string(device_id_) +
" was requested (0-indexed)");
}
}
cudaSetDevice(device_id_);
cudaMalloc(&origin_ptr_, allocation_size_);

size_t offset = 0;
Expand All @@ -97,6 +113,7 @@ namespace caspar {
}

// Destructor: frees the buffer referenced by origin_ptr_.
{{ solver.struct_name }}::~{{ solver.struct_name }}(){
// Make the solver's device current before freeing.
// NOTE(review): the previously current device is not restored afterwards —
// confirm callers do not rely on it.
cudaSetDevice(device_id_);
// Return codes of both CUDA calls are ignored here.
cudaFree(origin_ptr_);
}

Expand All @@ -110,6 +127,7 @@ size_t {{ solver.struct_name }}::get_allocation_size(){


SolveResult {{ solver.struct_name }}::solve(bool print_progress, bool verbose_logging) {
cudaSetDevice(device_id_);
SolveResult result;
result.exit_reason = ExitReason::MAX_ITERATIONS;
{{solver.linear_t}} score_best;
Expand All @@ -123,6 +141,7 @@ SolveResult {{ solver.struct_name }}::solve(bool print_progress, bool verbose_lo
std::chrono::time_point<std::chrono::steady_clock> t0 = std::chrono::steady_clock::now();
std::chrono::time_point<std::chrono::steady_clock> t_prev = t0;
score_best = DoResJacFirst();
result.initial_score = score_best;
if (print_progress) {
printf(" score_init: % .6e\n", score_best);
}
Expand Down Expand Up @@ -634,6 +653,7 @@ void {{ solver.struct_name }}::finish_indices() {

{% for nodetype in solver.node_types %}
void {{ solver.struct_name }}::Set{{nodetype.__name__}}Num(const size_t num) {
cudaSetDevice(device_id_);
if (num > {{num_max_key(nodetype)}}) {
throw std::runtime_error(std::to_string(num) + " > {{num_max_key(nodetype)}}");
}
Expand All @@ -642,6 +662,7 @@ void {{ solver.struct_name }}::Set{{nodetype.__name__}}Num(const size_t num) {

void {{ solver.struct_name }}::Set{{nodetype.__name__}}NodesFromStackedHost(
const {{solver.storage_t}}* const data, const size_t offset, const size_t num) {
cudaSetDevice(device_id_);
if (offset + num > {{num_key(nodetype)}}){
throw std::runtime_error(std::to_string(offset + num) + " > {{num_key(nodetype)}}");
}
Expand All @@ -654,6 +675,7 @@ void {{ solver.struct_name }}::Set{{nodetype.__name__}}NodesFromStackedHost(

void {{ solver.struct_name }}::Set{{nodetype.__name__}}NodesFromStackedDevice(
const {{solver.storage_t}}* const data, const size_t offset, const size_t num) {
cudaSetDevice(device_id_);
if (offset + num > {{num_key(nodetype)}}){
throw std::runtime_error(std::to_string(offset + num) + " > {{num_key(nodetype)}}");
}
Expand All @@ -663,6 +685,7 @@ void {{ solver.struct_name }}::Set{{nodetype.__name__}}NodesFromStackedDevice(

void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedHost(
{{solver.storage_t}}* const data, const size_t offset, const size_t num) {
cudaSetDevice(device_id_);
if (offset + num > {{num_key(nodetype)}}){
throw std::runtime_error(std::to_string(offset + num) + " > {{num_key(nodetype)}}");
}
Expand All @@ -675,6 +698,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedHost(

void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
{{solver.storage_t}}* const data, const size_t offset, const size_t num) {
cudaSetDevice(device_id_);
if (offset + num > {{num_key(nodetype)}}){
throw std::runtime_error(std::to_string(offset + num) + " > {{num_key(nodetype)}}");
}
Expand All @@ -695,6 +719,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
{% if fac.isnodeshared[arg] %}
void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromHost(
const unsigned int* const indices, size_t num) {
cudaSetDevice(device_id_);
if (num != {{num_key(fac)}}){
throw std::runtime_error(
std::to_string(num)
Expand All @@ -708,7 +733,8 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromDevice(
const unsigned int* const indices, size_t num) {
indices_valid_ = false;

cudaSetDevice(device_id_);

if (num != {{num_key(fac)}}){
throw std::runtime_error(
std::to_string(num)
Expand Down Expand Up @@ -737,6 +763,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
const {{solver.storage_t}}* const data, size_t offset, size_t num
{% endif %}
) {
cudaSetDevice(device_id_);
{% if fac.isconstuniq[arg] %}
const size_t offset = 0;
const size_t num = 1;
Expand Down Expand Up @@ -769,6 +796,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
{% elif fac.isconstindexed[arg] %}
const {{solver.storage_t}}* const data, size_t offset, size_t num
{% endif %} ) {
cudaSetDevice(device_id_);
{% if fac.isconstuniq[arg] %}
const size_t offset = 0;
const size_t num = 1;
Expand All @@ -791,6 +819,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
{% if fac.isconstshared[arg] %}
void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromHost(
const unsigned int* const indices, size_t num) {
cudaSetDevice(device_id_);
if (num != {{num_key(fac)}}){
throw std::runtime_error(
std::to_string(num)
Expand All @@ -804,7 +833,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromDevice(
const unsigned int* const indices, size_t num) {
indices_valid_ = false;

cudaSetDevice(device_id_);
if (num != {{num_key(fac)}}){
throw std::runtime_error(
std::to_string(num)
Expand All @@ -824,6 +853,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromHost(
const unsigned int* const indices, size_t num) {
indices_valid_ = false;
cudaSetDevice(device_id_);
if (num != {{num_key(fac)}}){
throw std::runtime_error(
std::to_string(num)
Expand All @@ -836,6 +866,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromDevice(
const unsigned int* const indices, size_t num) {
indices_valid_ = false;
cudaSetDevice(device_id_);
if (num != {{num_key(fac)}}){
throw std::runtime_error(
std::to_string(num)
Expand Down
4 changes: 3 additions & 1 deletion symforce/caspar/source/templates/solver.h.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ class {{ solver.struct_name }} {
{{ solver.struct_name }}(
const SolverParams<double> &params,
{% for thing in solver.size_contributors %}
size_t {{num_arg_key(thing)}}{{ ", " if not loop.last else "" }}
size_t {{num_arg_key(thing)}}{{ ", " }}
{% endfor %}
int device_id = 0
);

// This class is managing cuda memory and cannot be copied.
Expand Down Expand Up @@ -210,6 +211,7 @@ class {{ solver.struct_name }} {

private:
SolverParams<{{solver.linear_t}}> params_;
int device_id_;
uint8_t* origin_ptr_;
size_t scratch_inout_size_;
size_t allocation_size_;
Expand Down
6 changes: 4 additions & 2 deletions symforce/caspar/source/templates/solver_pybinding.h.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ inline void add_solver_pybinding(pybind11::module_ module) {
.def(py::init<SolverParams<double>,
{% for thing in solver.size_contributors %}
size_t{{ ", " if not loop.last else "" }}
{% endfor %}>(),
{% endfor %},
int>(),
py::arg("params"),
py::kw_only(),
{% for thing in solver.size_contributors %}
py::arg("{{num_arg_key(thing)}}") = 0{{ ", " if not loop.last else "" }}
py::arg("{{num_arg_key(thing)}}") = 0{{ ", " }}
{% endfor %}
py::arg("device_id") = 0
)

.def("set_params", &{{solver.struct_name}}::set_params)
Expand Down
Loading