Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions symforce/caspar/code_generation/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ def generate(self, out_dir: Path) -> None:
self.registers = [f"r{i}" for i in range(n_registers)]

code = env.get_template("kernel.cu.jinja").render(
kernel=self, snake_to_camel=parts_to_pascal
kernel=self, parts_to_pascal=parts_to_pascal
)
code = EMPTY_BLOCK_PATTERN.sub("", code)
write_if_different(code, out_dir.joinpath(f"kernel_{self.name}.cu"))

header = env.get_template("kernel.h.jinja").render(
kernel=self, snake_to_camel=parts_to_pascal
kernel=self, parts_to_pascal=parts_to_pascal
)
write_if_different(header, out_dir.joinpath(f"kernel_{self.name}.h"))

Expand Down
5 changes: 4 additions & 1 deletion symforce/caspar/code_generation/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from symforce.caspar.memory.dtype import DType
from symforce.ops import LieGroupOps as Ops
from symforce.python_util import camelcase_to_snakecase
from symforce.python_util import parts_to_pascal

from ..code_generation.factor import Factor
from ..code_generation.kernel import Kernel
Expand Down Expand Up @@ -241,7 +242,9 @@ def generate_castype_mappings(self, out_dir: Path) -> None:
write_if_different(definition, out_dir.joinpath("caspar_mappings_pybinding.h"))

def generate_binding_file(self, out_dir: Path, solver: Solver | None) -> None:
binding = env.get_template("pybinding.cc.jinja").render(caslib=self, solver=solver)
binding = env.get_template("pybinding.cc.jinja").render(
caslib=self, solver=solver, parts_to_pascal=parts_to_pascal
)
write_if_different(binding, out_dir.joinpath("pybinding.cc"))

def generate_buildfiles(self, out_dir: Path) -> None:
Expand Down
4 changes: 2 additions & 2 deletions symforce/caspar/code_generation/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from symforce import typing as T
from symforce.codegen.codegen import WARNING_MESSAGE
from symforce.ops import LieGroupOps as Ops
from symforce.python_util import snakecase_to_camelcase
from symforce.python_util import parts_to_pascal

from ..code_generation.factor import Factor
from ..code_generation.factor import dyn_part
Expand Down Expand Up @@ -681,7 +681,7 @@ def generate(self, out_dir: Path) -> None:
num_blocks_key=num_blocks_key,
num_max_key=num_max_key,
num_arg_key=num_arg_key,
snake_to_camel=snakecase_to_camelcase,
parts_to_pascal=parts_to_pascal,
Ops=Ops,
)
header = env.get_template("solver.h.jinja").render(**kwargs)
Expand Down
2 changes: 1 addition & 1 deletion symforce/caspar/examples/bal_deployed/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class Pixel(sf.V2):
def fac_reprojection(
cam: T.Annotated[Cam, mem.TunableShared],
point: T.Annotated[Point, mem.TunableShared],
pixel: T.Annotated[Pixel, mem.Constant],
pixel: T.Annotated[Pixel, mem.ConstantSequential],
) -> sf.V2:
cam_T_world = cam.pose
intrinsics = cam.calib
Expand Down
4 changes: 1 addition & 3 deletions symforce/caspar/examples/bal_deployed/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@

import numpy as np

from symforce.caspar.examples.bal_deployed.generated import ( # type: ignore[import-not-found]
caspar_lib,
)
from generated import caspar_lib # type: ignore[import-not-found]

try:
npz_path = Path(sys.argv[1])
Expand Down
14 changes: 6 additions & 8 deletions symforce/caspar/examples/kernel_example/gen_and_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,21 @@ def example_kernel(
caslib.generate(out_dir)
caslib.compile(out_dir)

# Can also be imported using: lib = caslib.import_lib(out_dir)
from symforce.caspar.examples.kernel_example.generated import ( # type: ignore[import-not-found]
caspar_lib as lib,
)
# lib = caslib.import_lib(out_dir)
from generated import caspar_lib as lib # type: ignore[import-not-found, unused-ignore]

N = 100
arg0_stacked = torch.rand(N, sf.V3.storage_dim(), device="cuda")
arg0_caspar = torch.empty(mem.caspar_size(sf.V3.storage_dim()), N, device="cuda")
lib.Matrix31_stacked_to_caspar(arg0_stacked, arg0_caspar)
lib.matrix31_stacked_to_caspar(arg0_stacked, arg0_caspar)

arg0_indices = torch.randint(0, N, (N,), device="cuda", dtype=torch.int32)
arg0_indices_shared = torch.empty(N, 2, device="cuda", dtype=torch.int32)
lib.shared_indices(arg0_indices, arg0_indices_shared)

arg1_stacked = torch.rand(1, 6, device="cuda")
arg1_caspar = torch.empty(mem.caspar_size(6), 1, device="cuda")
lib.Matrix61_stacked_to_caspar(arg1_stacked, arg1_caspar)
lib.matrix61_stacked_to_caspar(arg1_stacked, arg1_caspar)

BLOCK_SIZE = 1024
OUT0_IDX_MAX = 10
Expand All @@ -87,8 +85,8 @@ def example_kernel(
out0_sharedsum = torch.zeros(OUT0_IDX_MAX, 2, device="cuda")
out1_indexed = torch.empty(N, 1, device="cuda")

lib.Matrix21_caspar_to_stacked(out0_caspar, out0_sharedsum)
lib.Symbol_caspar_to_stacked(out1_caspar, out1_indexed)
lib.matrix21_caspar_to_stacked(out0_caspar, out0_sharedsum)
lib.symbol_caspar_to_stacked(out1_caspar, out1_indexed)

# Check the results
sincos = 2 * torch.stack([torch.sin(arg0_stacked[:, 0]), torch.cos(arg0_stacked[:, 0])], dim=1)
Expand Down
15 changes: 8 additions & 7 deletions symforce/caspar/examples/multiple_factors/gen_and_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def to_tensor(storage: sf.Storage) -> torch.Tensor:
caslib.generate(out_dir) # Can be commented out after the first run to avoid regenerating (slow)
caslib.compile(out_dir) # Can be commented out after the first run to avoid recompiling (slow)


# Can also be imported using:
# lib = caslib.import_lib(out_dir)
from generated import caspar_lib as lib # type: ignore[import-not-found, unused-ignore]
Expand Down Expand Up @@ -181,15 +182,15 @@ def to_tensor(storage: sf.Storage) -> torch.Tensor:

# Map the generated Caspar data to regular array of structs (AOS) format.
pose_stacked = torch.empty(N_POSE, mem.stacked_size(Pose))
lib.Pose_caspar_to_stacked(pose_caspar, pose_stacked)
lib.pose_caspar_to_stacked(pose_caspar, pose_stacked)
landmarks_stacked = torch.empty(N_LANDMARK, mem.stacked_size(Landmark))
lib.Landmark_caspar_to_stacked(landmarks_caspar, landmarks_stacked)
lib.landmark_caspar_to_stacked(landmarks_caspar, landmarks_stacked)
odometry_stacked = torch.empty(N_POSE - 1, mem.stacked_size(OdometryMeasurement))
lib.OdometryMeasurement_caspar_to_stacked(odometry_caspar, odometry_stacked)
lib.odometry_measurement_caspar_to_stacked(odometry_caspar, odometry_stacked)
pos_meas_stacked = torch.empty(N_GNSS, mem.stacked_size(PositionMeasurement))
lib.posMeasurement_caspar_to_stacked(pos_meas_caspar, pos_meas_stacked)
lib.position_measurement_caspar_to_stacked(pos_meas_caspar, pos_meas_stacked)
landmark_meas_stacked = torch.empty(N_LANDMARK_ERROR, mem.stacked_size(LandmarkMeasurement))
lib.LandmarkMeasurement_caspar_to_stacked(landmark_meas_caspar, landmark_meas_stacked)
lib.landmark_measurement_caspar_to_stacked(landmark_meas_caspar, landmark_meas_stacked)


# Add some noise to the data.
Expand All @@ -210,15 +211,15 @@ def to_tensor(storage: sf.Storage) -> torch.Tensor:
params,
Pose_num_max=N_POSE,
Landmark_num_max=N_LANDMARK,
posSensorOffset_num_max=1,
PositionSensorOffset_num_max=1,
LandmarkSensorOffset_num_max=1,
pos_error_num_max=N_GNSS,
landmark_error_num_max=N_LANDMARK_ERROR,
odometry_error_num_max=N_POSE - 1,
)


solver.set_posSensorOffset_nodes_from_stacked_device(pos_sensor_offset)
solver.set_PositionSensorOffset_nodes_from_stacked_device(pos_sensor_offset)
solver.set_LandmarkSensorOffset_nodes_from_stacked_device(landmark_sensor_offset)

# To demonstrade how to update the solver dynamically we start by loading and optimizing only half the problem.
Expand Down
6 changes: 3 additions & 3 deletions symforce/caspar/source/templates/kernel.cu.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace caspar {

__global__ void
__launch_bounds__({{kernel.block_size}}, 1)
{{snake_to_camel(kernel.name)}}Kernel(
{{parts_to_pascal(kernel.name)}}Kernel(
{% for accessor in kernel.accessors %}
{% for name, typ in accessor.kernel_sig.items() %}
{{typ}} {{name}},
Expand Down Expand Up @@ -57,7 +57,7 @@ __launch_bounds__({{kernel.block_size}}, 1)
}


void {{snake_to_camel(kernel.name)}} (
void {{parts_to_pascal(kernel.name)}} (
{% for accessor in kernel.accessors %}
{% for name, typ in accessor.kernel_sig.items() %}
{{typ}} {{name}},
Expand All @@ -79,7 +79,7 @@ void {{snake_to_camel(kernel.name)}} (
{{accessor.pre_kernel_code()}}
{% endif %}
{% endfor %}
{{snake_to_camel(kernel.name)}}Kernel<<<n_blocks, {{kernel.block_size}}>>>(
{{parts_to_pascal(kernel.name)}}Kernel<<<n_blocks, {{kernel.block_size}}>>>(
{% for accessor in kernel.accessors %}
{% for name in accessor.kernel_sig %}
{{name}},
Expand Down
2 changes: 1 addition & 1 deletion symforce/caspar/source/templates/kernel.h.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

namespace caspar {

void {{snake_to_camel(kernel.name)}} (
void {{parts_to_pascal(kernel.name)}} (
{% for desc in kernel.accessors %}
{% for name, typ in desc.kernel_sig.items() %}
{{ typ }} {{ name }},
Expand Down
6 changes: 3 additions & 3 deletions symforce/caspar/source/templates/pybinding.cc.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ using namespace caspar;

{% for kernel in caslib.kernels %}
{% if kernel.expose_to_python %}
void {{snake_to_camel(kernel.name)}}Pybinding(
void {{parts_to_pascal(kernel.name)}}Pybinding(
{% for accessor in kernel.accessors %}
{% for name, typ in accessor.py_sig.items() %}
{{typ}} {{name}},
Expand All @@ -38,7 +38,7 @@ void {{snake_to_camel(kernel.name)}}Pybinding(
{% endfor %}
{% endfor %}

{{snake_to_camel(kernel.name)}}(
{{parts_to_pascal(kernel.name)}}(
{% for accessor in kernel.accessors %}
{% for arg in accessor.py_args %}
{{arg}},
Expand All @@ -59,7 +59,7 @@ PYBIND11_MODULE({{caslib.name}}, module) {
module.def("shared_indices", &caspar::shared_indices_pybinding);
{% for kernel in caslib.kernels %}
{% if kernel.expose_to_python %}
module.def("{{kernel.name}}", &{{snake_to_camel(kernel.name)}}_pybinding);
module.def("{{kernel.name}}", &{{parts_to_pascal(kernel.name)}}Pybinding);
{% endif %}
{% endfor %}

Expand Down
28 changes: 14 additions & 14 deletions symforce/caspar/source/templates/solver.cc.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ SolveResult {{ solver.struct_name }}::solve(bool print_progress, bool verbose_lo

{% for fac in solver.factors %}

{{snake_to_camel(fac.name)}}ResJac{{isfirst}}(
{{parts_to_pascal(fac.name)}}ResJac{{isfirst}}(
{% for arg, typ in fac.arg_types.items() %}
{% if fac.isnode[arg] %}
{% if fac.isnodepair[arg] %}
Expand Down Expand Up @@ -406,7 +406,7 @@ void {{ solver.struct_name }}::DoUpdateMp() {
void {{ solver.struct_name }}::DoJtjpDirect() {
{% for fac in solver.factors %}
{% if not fac.solved_by_preconditioner %}
{{snake_to_camel(fac.name)}}JtjnjtrDirect(
{{parts_to_pascal(fac.name)}}JtjnjtrDirect(
{% for arg, nodetype in fac.node_arg_types.items() %}
{% if fac.isnodepair[arg] %}
{{node_key(nodetype, "p")}},
Expand Down Expand Up @@ -542,7 +542,7 @@ void {{ solver.struct_name }}::DoUpdateR() {
{% endfor %}
Zero({{solver_key("res_tot")}}, {{solver_key("res_tot")}}+1);
{% for fac in solver.factors %}
{{snake_to_camel(fac.name)}}Score(
{{parts_to_pascal(fac.name)}}Score(
{% for arg, typ in fac.arg_types.items() %}
{% if fac.isnode[arg] %}
{% if fac.isnodepair[arg] %}
Expand Down Expand Up @@ -685,15 +685,15 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
{% endfor%}

{% for fac in solver.factors %}
void {{ solver.struct_name }}::Set{{snake_to_camel(fac.name)}}Num(const size_t num) {
void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}Num(const size_t num) {
if (num > {{num_max_key(fac)}}){
throw std::runtime_error(std::to_string(num) + " > {{num_max_key(fac)}}");
}
{{num_key(fac)}} = num;
}
{% for arg, argtype in fac.node_arg_types.items() %}
{% if fac.isnodeshared[arg] %}
void {{ solver.struct_name }}::Set{{snake_to_camel(fac.name)}}{{snake_to_camel(arg)}}IndicesFromHost(
void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromHost(
const unsigned int* const indices, size_t num) {
if (num != {{num_key(fac)}}){
throw std::runtime_error(
Expand All @@ -702,10 +702,10 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
}
cudaMemcpy((unsigned int*)marker__scratch_inout_, indices, num * sizeof(unsigned int),
cudaMemcpyHostToDevice);
Set{{snake_to_camel(fac.name)}}{{snake_to_camel(arg)}}IndicesFromDevice((unsigned int*)marker__scratch_inout_, num);
Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromDevice((unsigned int*)marker__scratch_inout_, num);
}

void {{ solver.struct_name }}::Set{{snake_to_camel(fac.name)}}{{snake_to_camel(arg)}}IndicesFromDevice(
void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromDevice(
const unsigned int* const indices, size_t num) {
indices_valid_ = false;

Expand All @@ -726,7 +726,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
{% endif %}
{% endfor %}
{% for arg, argtype in fac.const_arg_types.items() %}
void {{ solver.struct_name }}::Set{{snake_to_camel(fac.name)}}{{snake_to_camel(arg)}}DataFromStackedHost(
void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}DataFromStackedHost(
{% if fac.isconstseq[arg] %}
const {{solver.storage_t}}* const data, size_t offset, size_t num
{% elif fac.isconstuniq[arg] %}
Expand Down Expand Up @@ -759,7 +759,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
offset, num);
}

void {{ solver.struct_name }}::Set{{snake_to_camel(fac.name)}}{{snake_to_camel(arg)}}DataFromStackedDevice(
void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}DataFromStackedDevice(
{% if fac.isconstseq[arg] %}
const {{solver.storage_t}}* const data, size_t offset, size_t num
{% elif fac.isconstuniq[arg] %}
Expand Down Expand Up @@ -789,7 +789,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
offset, num);
}
{% if fac.isconstshared[arg] %}
void {{ solver.struct_name }}::Set{{snake_to_camel(fac.name)}}{{snake_to_camel(arg)}}IndicesFromHost(
void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromHost(
const unsigned int* const indices, size_t num) {
if (num != {{num_key(fac)}}){
throw std::runtime_error(
Expand All @@ -798,10 +798,10 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
}
cudaMemcpy((unsigned int*)marker__scratch_inout_, indices, num * sizeof(unsigned int),
cudaMemcpyHostToDevice);
Set{{snake_to_camel(fac.name)}}{{snake_to_camel(arg)}}IndicesFromDevice((unsigned int*)marker__scratch_inout_, num);
Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromDevice((unsigned int*)marker__scratch_inout_, num);
}

void {{ solver.struct_name }}::Set{{snake_to_camel(fac.name)}}{{snake_to_camel(arg)}}IndicesFromDevice(
void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromDevice(
const unsigned int* const indices, size_t num) {
indices_valid_ = false;

Expand All @@ -821,7 +821,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(

}
{% elif fac.isconstindexed[arg] %}
void {{ solver.struct_name }}::Set{{snake_to_camel(fac.name)}}{{snake_to_camel(arg)}}IndicesFromHost(
void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromHost(
const unsigned int* const indices, size_t num) {
indices_valid_ = false;
if (num != {{num_key(fac)}}){
Expand All @@ -833,7 +833,7 @@ void {{solver.struct_name}}::Get{{nodetype.__name__}}NodesToStackedDevice(
cudaMemcpyHostToDevice);
}

void {{ solver.struct_name }}::Set{{snake_to_camel(fac.name)}}{{snake_to_camel(arg)}}IndicesFromDevice(
void {{ solver.struct_name }}::Set{{parts_to_pascal(fac.name)}}{{parts_to_pascal(arg)}}IndicesFromDevice(
const unsigned int* const indices, size_t num) {
indices_valid_ = false;
if (num != {{num_key(fac)}}){
Expand Down
Loading
Loading