From 2b0f856ce56568255f13c2cb491f93d7e754c93b Mon Sep 17 00:00:00 2001 From: tordnat Date: Tue, 28 Apr 2026 15:53:13 +0200 Subject: [PATCH] add parts_to_pascal to fix conversion from non-camel case --- symforce/caspar/code_generation/kernel.py | 6 +++--- symforce/python_util.py | 10 ++++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/symforce/caspar/code_generation/kernel.py b/symforce/caspar/code_generation/kernel.py index c359ea55..26b039c3 100644 --- a/symforce/caspar/code_generation/kernel.py +++ b/symforce/caspar/code_generation/kernel.py @@ -8,7 +8,7 @@ import re from pathlib import Path -from symforce.python_util import snakecase_to_camelcase +from symforce.python_util import parts_to_pascal from ..code_formulation.dabseg_from_accessors import make_dabseg from ..code_formulation.dabseg_sorter import get_lines @@ -53,13 +53,13 @@ def generate(self, out_dir: Path) -> None: self.registers = [f"r{i}" for i in range(n_registers)] code = env.get_template("kernel.cu.jinja").render( - kernel=self, snake_to_camel=snakecase_to_camelcase + kernel=self, snake_to_camel=parts_to_pascal ) code = EMPTY_BLOCK_PATTERN.sub("", code) write_if_different(code, out_dir.joinpath(f"kernel_{self.name}.cu")) header = env.get_template("kernel.h.jinja").render( - kernel=self, snake_to_camel=snakecase_to_camelcase + kernel=self, snake_to_camel=parts_to_pascal ) write_if_different(header, out_dir.joinpath(f"kernel_{self.name}.h")) diff --git a/symforce/python_util.py b/symforce/python_util.py index 2120e087..dbf16e32 100644 --- a/symforce/python_util.py +++ b/symforce/python_util.py @@ -107,6 +107,16 @@ def snakecase_to_camelcase(s: str) -> str: ) +def parts_to_pascal(s: str) -> str: + """ + Join underscore-separated parts into PascalCase, preserving internal capitalization. + + Unlike snakecase_to_camelcase, this does not lowercase the interior of each part, so + already-PascalCase parts are preserved: "PinholePose_update_p" -> "PinholePoseUpdateP". + """ + return "".join(part[0].upper() + part[1:] for part in s.split("_") if part) + + def snakecase_to_lower_camelcase(s: str) -> str: """ Convert snake_case -> lowerCamelCase