Skip to content
34 changes: 21 additions & 13 deletions symforce/caspar/code_generation/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,20 +154,23 @@ def add_factor(
insert_sorted_unique(self.exposed_types, factor.arg_types.values())
return func_or_name

def generate(self, out_dir: Path, use_symlinks: bool = False) -> None:
def generate(
self, out_dir: Path, use_symlinks: bool = False, python_bindings: bool = True
) -> None:
out_dir.mkdir(exist_ok=True, parents=True)

for fac in self.factors:
self.kernels.extend(fac.make_kernels())

self.generate_castype_mappings(out_dir)
self.generate_links(out_dir, use_symlinks)
self.generate_castype_mappings(out_dir, python_bindings)
self.generate_links(out_dir, use_symlinks, python_bindings)
if solver := (Solver(self) if self.factors else None):
self.kernels.extend(solver.make_kernels())
solver.generate(out_dir)
self.generate_binding_file(out_dir, solver)
self.generate_stubs(out_dir, solver)
self.generate_buildfiles(out_dir)
solver.generate(out_dir, python_bindings)
if python_bindings:
self.generate_binding_file(out_dir, solver)
self.generate_stubs(out_dir, solver)
self.generate_buildfiles(out_dir, python_bindings)
self.generate_kernels(out_dir)

@staticmethod
Expand Down Expand Up @@ -242,10 +245,14 @@ def generate_kernels(self, out_dir: Path) -> None:
kernel.generate(out_dir)

@staticmethod
def generate_links(out_dir: Path, use_symlinks: bool = True) -> None:
def generate_links(
out_dir: Path, use_symlinks: bool = True, python_bindings: bool = True
) -> None:
for f in Path(caspar.__file__).parent.glob("source/runtime/*"):
if f.is_dir(): # Skip directories like __pycache__
continue
if not python_bindings and "pybind" in f.name:
continue
f_new = out_dir / f.name
if use_symlinks:
if f_new.exists():
Expand All @@ -254,7 +261,7 @@ def generate_links(out_dir: Path, use_symlinks: bool = True) -> None:
else:
copy_if_different(f, f_new)

def generate_castype_mappings(self, out_dir: Path) -> None:
def generate_castype_mappings(self, out_dir: Path, python_bindings: bool = True) -> None:
"""
Generates code to perform mapping between stacked format (array of structs)
and the caspar layout of the corresponding types.
Expand All @@ -271,14 +278,15 @@ def generate_castype_mappings(self, out_dir: Path) -> None:
write_if_different(definition, out_dir.joinpath("caspar_mappings.cu"))
definition = env.get_template("caspar_mappings.h.jinja").render(**kwargs)
write_if_different(definition, out_dir.joinpath("caspar_mappings.h"))
definition = env.get_template("caspar_mappings_pybinding.h.jinja").render(**kwargs)
write_if_different(definition, out_dir.joinpath("caspar_mappings_pybinding.h"))
if python_bindings:
definition = env.get_template("caspar_mappings_pybinding.h.jinja").render(**kwargs)
write_if_different(definition, out_dir.joinpath("caspar_mappings_pybinding.h"))

def generate_binding_file(self, out_dir: Path, solver: Solver | None) -> None:
binding = env.get_template("pybinding.cc.jinja").render(caslib=self, solver=solver)
write_if_different(binding, out_dir.joinpath("pybinding.cc"))

def generate_buildfiles(self, out_dir: Path) -> None:
def generate_buildfiles(self, out_dir: Path, python_bindings: bool = True) -> None:
for template in env.list_templates(filter_func=lambda t: t.startswith("buildfiles")):
content = env.get_template(template).render(caslib=self)
content = env.get_template(template).render(caslib=self, python_bindings=python_bindings)
write_if_different(content, out_dir.joinpath(Path(template).stem))
7 changes: 4 additions & 3 deletions symforce/caspar/code_generation/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ def set_nonzero(typ: T.Type[sf.Storage], jac: sf.Matrix) -> None:
)
return list(self.kernels.values())

def generate(self, out_dir: Path) -> None:
def generate(self, out_dir: Path, python_bindings: bool = True) -> None:
kwargs = dict(
solver=self,
name_key=name_key,
Expand All @@ -688,5 +688,6 @@ def generate(self, out_dir: Path) -> None:
write_if_different(header, out_dir.joinpath("solver.h"))
definition = env.get_template("solver.cc.jinja").render(**kwargs)
write_if_different(definition, out_dir.joinpath("solver.cc"))
definition = env.get_template("solver_pybinding.h.jinja").render(**kwargs)
write_if_different(definition, out_dir.joinpath("solver_pybinding.h"))
if python_bindings:
definition = env.get_template("solver_pybinding.h.jinja").render(**kwargs)
write_if_different(definition, out_dir.joinpath("solver_pybinding.h"))
51 changes: 30 additions & 21 deletions symforce/caspar/source/templates/buildfiles/CMakeLists.txt.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,27 @@ endif()

find_package(CUDAToolkit REQUIRED)

include(FetchContent)
find_package(pybind11 2.13.6 QUIET)
if(NOT pybind11_FOUND)
message(STATUS "pybind11 not found, adding with FetchContent")
FetchContent_Declare(
pybind11
URL https://github.com/pybind/pybind11/archive/v2.13.6.zip
URL_HASH SHA256=d0a116e91f64a4a2d8fb7590c34242df92258a61ec644b79127951e821b47be6
DOWNLOAD_EXTRACT_TIMESTAMP TRUE
)
FetchContent_MakeAvailable(pybind11)
else()
message(STATUS "pybind11 found")
{% if python_bindings %}
option(CASPAR_BUILD_PYTHON_BINDINGS "Build Python bindings via pybind11" ON)

if(CASPAR_BUILD_PYTHON_BINDINGS)
include(FetchContent)
find_package(pybind11 2.13.6 QUIET)
if(NOT pybind11_FOUND)
message(STATUS "pybind11 not found, adding with FetchContent")
FetchContent_Declare(
pybind11
URL https://github.com/pybind/pybind11/archive/v2.13.6.zip
URL_HASH SHA256=d0a116e91f64a4a2d8fb7590c34242df92258a61ec644b79127951e821b47be6
DOWNLOAD_EXTRACT_TIMESTAMP TRUE
)
FetchContent_MakeAvailable(pybind11)
else()
message(STATUS "pybind11 found")
endif()
endif()

{% endif %}
file(GLOB CUDA_SOURCES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cu" "*.cc")
list(FILTER CUDA_SOURCES EXCLUDE REGEX "pybind.*")
file(GLOB CUDA_HEADERS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cuh" "*.h")
Expand All @@ -47,13 +53,16 @@ set_target_properties({{caslib.name}}_core PROPERTIES
POSITION_INDEPENDENT_CODE ON
)

file(GLOB CPP_SOURCES "pybind*.cc")
file(GLOB CPP_HEADERS "*.h")
{% if python_bindings %}
if(CASPAR_BUILD_PYTHON_BINDINGS)
file(GLOB CPP_SOURCES "pybind*.cc")

pybind11_add_module({{caslib.name}} ${CPP_SOURCES} ${CPP_HEADERS})
target_link_libraries({{caslib.name}} PRIVATE {{caslib.name}}_core)
pybind11_add_module({{caslib.name}} ${CPP_SOURCES} ${CPP_HEADERS})
target_link_libraries({{caslib.name}} PRIVATE {{caslib.name}}_core)

set_target_properties({{caslib.name}} PROPERTIES
LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
)
set_target_properties({{caslib.name}} PROPERTIES
LIBRARY_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
)
endif()
{% endif %}
Loading