Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions symforce/caspar/code_generation/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import re
from pathlib import Path

from symforce.python_util import snakecase_to_camelcase
from symforce.python_util import parts_to_pascal

from ..code_formulation.dabseg_from_accessors import make_dabseg
from ..code_formulation.dabseg_sorter import get_lines
Expand Down Expand Up @@ -53,13 +53,13 @@ def generate(self, out_dir: Path) -> None:
self.registers = [f"r{i}" for i in range(n_registers)]

code = env.get_template("kernel.cu.jinja").render(
kernel=self, snake_to_camel=snakecase_to_camelcase
kernel=self, snake_to_camel=parts_to_pascal
)
code = EMPTY_BLOCK_PATTERN.sub("", code)
write_if_different(code, out_dir.joinpath(f"kernel_{self.name}.cu"))

header = env.get_template("kernel.h.jinja").render(
kernel=self, snake_to_camel=snakecase_to_camelcase
kernel=self, snake_to_camel=parts_to_pascal
)
write_if_different(header, out_dir.joinpath(f"kernel_{self.name}.h"))

Expand Down
10 changes: 10 additions & 0 deletions symforce/python_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,16 @@ def snakecase_to_camelcase(s: str) -> str:
)


def parts_to_pascal(s: str) -> str:
"""
Join underscore-separated parts into PascalCase, preserving internal capitalization.

Unlike snakecase_to_camelcase, this does not lowercase the interior of each part, so
already-PascalCase parts are preserved: "PinholePose_update_p" -> "PinholePoseUpdateP".
"""
return "".join(part[0].upper() + part[1:] for part in s.split("_") if part)


def snakecase_to_lower_camelcase(s: str) -> str:
"""
Convert snake_case -> lowerCamelCase
Expand Down
Loading