Skip to content

Commit 24eaf20

Browse files
committed
[SymForce] Add CodegenConfig.custom_preamble
To optionally prepend a string on the generated template. I've also refactored some of the functions that take CodegenConfig arguments, to make them required in more places (which caught a few in consistencies in how they were handled in some functions). Review by commit Topic: sf-preamble Reviewers: bradley,nathan,hayk,chao,danny,william-s,eric GitOrigin-RevId: e115dc05b90268d97c80f8dc9fc0a34b4f3fad6b
1 parent a4d4bbc commit 24eaf20

21 files changed

+109
-34
lines changed

symforce/benchmarks/matrix_multiplication/generate_matrix_multiplication_benchmark.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ def compute_A(*symbols: T.List[sf.Symbol]) -> sf.Matrix:
117117

118118
# These files are large enough that autoformatting them is very slow, so just don't do it
119119
config = codegen.CppConfig(
120-
cse_optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)], autoformat=False
120+
cse_optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)],
121+
render_template_config=codegen.RenderTemplateConfig(autoformat=False),
121122
)
122123
config_noinline = dataclasses.replace(config, force_no_inline=True)
123124

@@ -203,6 +204,7 @@ def compute_A(*symbols: T.List[sf.Symbol]) -> sf.Matrix:
203204
n_symbols=N_SYMBOLS,
204205
cant_allocate_on_stack=cant_allocate_on_stack,
205206
),
207+
config=codegen.RenderTemplateConfig(),
206208
output_path=output_dir / f"matrix_multiplication_benchmark_{matrix_name}.cc",
207209
)
208210

symforce/codegen/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414
from .codegen import InvalidNamespaceError
1515
from .codegen import LinearizationMode
1616
from .codegen_config import CodegenConfig
17+
from .codegen_config import RenderTemplateConfig

symforce/codegen/backends/cpp/cpp_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class CppConfig(CodegenConfig):
2525
line_length: Maximum allowed line length in docstrings; used for formatting docstrings.
2626
use_eigen_types: Use eigen_lcm types for vectors instead of lists
2727
autoformat: Run a code formatter on the generated code
28+
custom_preamble: An optional string to be prepended on the front of the rendered template
2829
cse_optimizations: Optimizations argument to pass to sf.cse
2930
zero_epsilon_behavior: What should codegen do if a default epsilon is not set?
3031
support_complex: Generate code that can work with std::complex or with regular float types

symforce/codegen/backends/python/python_config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class PythonConfig(CodegenConfig):
2424
line_length: Maximum allowed line length in docstrings; used for formatting docstrings.
2525
use_eigen_types: Use eigen_lcm types for vectors instead of lists
2626
autoformat: Run a code formatter on the generated code
27+
custom_preamble: An optional string to be prepended on the front of the rendered template
2728
cse_optimizations: Optimizations argument to pass to sf.cse
2829
zero_epsilon_behavior: What should codegen do if a default epsilon is not set?
2930
use_numba: Add the `@numba.njit` decorator to generated functions. This will greatly

symforce/codegen/cam_package_codegen.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,9 @@ def generate(config: CodegenConfig, output_dir: Path = None) -> Path:
279279
output_path = cam_package_dir / relative_path.replace(
280280
"CLASS", python_util.camelcase_to_snakecase(cls.__name__)
281281
)
282-
templates.add(template_path, data, output_path=output_path)
282+
templates.add(
283+
template_path, data, config.render_template_config, output_path=output_path
284+
)
283285

284286
# Package init
285287
# NOTE(brad): We already do this in geo_package_codegen.py. We need it there in case we
@@ -293,6 +295,7 @@ def generate(config: CodegenConfig, output_dir: Path = None) -> Path:
293295
all_types=list(geo_package_codegen.DEFAULT_GEO_TYPES) + list(DEFAULT_CAM_TYPES),
294296
numeric_epsilon=sf.numeric_epsilon,
295297
),
298+
config=config.render_template_config,
296299
output_path=cam_package_dir / "__init__.py",
297300
)
298301

@@ -306,6 +309,7 @@ def generate(config: CodegenConfig, output_dir: Path = None) -> Path:
306309
cam_cal_from_points=cam_cal_from_points,
307310
_DISTORTION_COEFF_VALS=_DISTORTION_COEFF_VALS,
308311
),
312+
config=config.render_template_config,
309313
)
310314

311315
elif isinstance(config, CppConfig):
@@ -335,18 +339,22 @@ def generate(config: CodegenConfig, output_dir: Path = None) -> Path:
335339
output_path = cam_package_dir / relative_path.replace(
336340
"CLASS", python_util.camelcase_to_snakecase(cls.__name__)
337341
)
338-
templates.add(template_path, data, output_path=output_path)
342+
templates.add(
343+
template_path, data, config.render_template_config, output_path=output_path
344+
)
339345

340346
# Add Camera and PosedCamera
341347
templates.add(
342348
template_path=Path("cam_package", "camera.h.jinja"),
343349
output_path=cam_package_dir / "camera.h",
344350
data=camera_data(),
351+
config=config.render_template_config,
345352
)
346353
templates.add(
347354
template_path=Path("cam_package") / "posed_camera.h.jinja",
348355
output_path=cam_package_dir / "posed_camera.h",
349356
data=posed_camera_data(),
357+
config=config.render_template_config,
350358
)
351359

352360
# Test example
@@ -381,6 +389,7 @@ def supports_camera_ray_from_pixel(cls: T.Type) -> bool:
381389
if supports_camera_ray_from_pixel(cls)
382390
],
383391
),
392+
config=config.render_template_config,
384393
)
385394
else:
386395
raise NotImplementedError(f'Unknown config type: "{config}"')

symforce/codegen/codegen.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -514,14 +514,15 @@ def generate_function(
514514
# Get templates to render
515515
for source, dest in self.config.templates_to_render(generated_file_name):
516516
templates.add(
517-
source,
518-
template_data,
517+
template_path=source,
518+
data=template_data,
519+
config=self.config.render_template_config,
519520
template_dir=template_dir,
520521
output_path=out_function_dir / dest,
521522
)
522523

523524
# Render
524-
templates.render(autoformat=self.config.autoformat)
525+
templates.render()
525526

526527
lcm_data = codegen_util.generate_lcm_types(
527528
lcm_type_dir=types_codegen_data.lcm_type_dir,

symforce/codegen/codegen_config.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,22 @@ class ZeroEpsilonBehavior(Enum):
3030
DEFAULT_ZERO_EPSILON_BEHAVIOR = ZeroEpsilonBehavior.WARN
3131

3232

33-
# TODO(hayk): Address this type ignore, which comes from having abstract methods on a dataclass.
34-
@dataclass # type: ignore
33+
@dataclass
34+
class RenderTemplateConfig:
35+
"""
36+
Arguments to template_util.render_template
37+
38+
Args:
39+
autoformat: Run a code formatter on the generated code
40+
custom_preamble: An optional string to be prepended on the front of the rendered template
41+
"""
42+
43+
autoformat: bool = True
44+
custom_preamble: str = ""
45+
46+
47+
# TODO(hayk): This type ignore is fixed by https://github.com/python/mypy/pull/13398 in mypy 0.981
48+
@dataclass # type: ignore[misc]
3549
class CodegenConfig:
3650
"""
3751
Base class for backend-specific arguments for code generation.
@@ -41,15 +55,16 @@ class CodegenConfig:
4155
block-style docstrings
4256
line_length: Maximum allowed line length in docstrings; used for formatting docstrings.
4357
use_eigen_types: Use eigen_lcm types for vectors instead of lists
44-
autoformat: Run a code formatter on the generated code
58+
render_template_config: Configuration for template rendering, see RenderTemplateConfig for
59+
more information
4560
cse_optimizations: Optimizations argument to pass to sf.cse
4661
zero_epsilon_behavior: What should codegen do if a default epsilon is not set?
4762
"""
4863

4964
doc_comment_line_prefix: str
5065
line_length: int
5166
use_eigen_types: bool
52-
autoformat: bool = True
67+
render_template_config: RenderTemplateConfig = field(default_factory=RenderTemplateConfig)
5368
cse_optimizations: T.Optional[
5469
T.Union[T.Literal["basic"], T.Sequence[T.Tuple[T.Callable, T.Callable]]]
5570
] = None

symforce/codegen/geo_package_codegen.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,12 +208,15 @@ def generate(config: CodegenConfig, output_dir: Path = None) -> Path:
208208
):
209209
template_path = Path(base_dir, relative_path + ".jinja")
210210
output_path = package_dir / relative_path.replace("CLASS", cls.__name__.lower())
211-
templates.add(template_path, data, output_path=output_path)
211+
templates.add(
212+
template_path, data, config.render_template_config, output_path=output_path
213+
)
212214

213215
templates.add(
214216
template_path=Path("ops", "__init__.py.jinja"),
215217
output_path=package_dir / "ops" / "__init__.py",
216218
data={},
219+
config=config.render_template_config,
217220
)
218221

219222
# Package init
@@ -224,6 +227,7 @@ def generate(config: CodegenConfig, output_dir: Path = None) -> Path:
224227
all_types=DEFAULT_GEO_TYPES,
225228
numeric_epsilon=sf.numeric_epsilon,
226229
),
230+
config=config.render_template_config,
227231
output_path=package_dir / "__init__.py",
228232
)
229233

@@ -232,6 +236,7 @@ def generate(config: CodegenConfig, output_dir: Path = None) -> Path:
232236
templates.add(
233237
template_path=Path("tests", name + ".jinja"),
234238
data=dict(Codegen.common_data(), all_types=DEFAULT_GEO_TYPES),
239+
config=config.render_template_config,
235240
output_path=output_dir / "tests" / name,
236241
)
237242

@@ -261,7 +266,9 @@ def generate(config: CodegenConfig, output_dir: Path = None) -> Path:
261266
):
262267
template_path = Path(base_dir, f"{relative_path}.jinja")
263268
output_path = package_dir / relative_path.replace("CLASS", cls.__name__.lower())
264-
templates.add(template_path, data, output_path=output_path)
269+
templates.add(
270+
template_path, data, config.render_template_config, output_path=output_path
271+
)
265272

266273
# Render non geo type specific templates
267274
for template_name in python_util.files_in_dir(
@@ -276,6 +283,7 @@ def generate(config: CodegenConfig, output_dir: Path = None) -> Path:
276283
templates.add(
277284
template_path=Path("geo_package", "ops", template_name),
278285
data=dict(Codegen.common_data()),
286+
config=config.render_template_config,
279287
output_path=package_dir / "ops" / template_name[: -len(".jinja")],
280288
)
281289

@@ -298,6 +306,7 @@ def generate(config: CodegenConfig, output_dir: Path = None) -> Path:
298306
for scalar in data["scalar_types"]
299307
],
300308
),
309+
config=config.render_template_config,
301310
)
302311

303312
else:
@@ -307,6 +316,7 @@ def generate(config: CodegenConfig, output_dir: Path = None) -> Path:
307316
templates.add(
308317
template_path="symforce_types.lcm.jinja",
309318
data=lcm_types_codegen.lcm_symforce_types_data(),
319+
config=config.render_template_config,
310320
template_dir=template_util.LCM_TEMPLATE_DIR,
311321
output_path=package_dir / ".." / "lcmtypes" / "lcmtypes" / "symforce_types.lcm",
312322
)

symforce/codegen/similarity_index.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ def __hash__(self) -> int:
6868
tuple(self.outputs.to_storage()),
6969
self.return_key,
7070
self.sorted_sparse_matrices,
71-
tuple(dataclasses.asdict(self.config).items()),
71+
# Convert to key, value tuples recursively. Unlike astuple, this has field names
72+
dataclasses.asdict(
73+
self.config, dict_factory=T.cast(T.Callable[[T.List], T.Tuple], tuple)
74+
),
7275
)
7376
)

symforce/codegen/sym_util_package_codegen.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def generate(config: codegen.CodegenConfig, output_dir: Path = None) -> Path:
3434
template_path="typedefs.h.jinja",
3535
output_path=package_dir / "typedefs.h",
3636
data={},
37+
config=config.render_template_config,
3738
template_dir=template_dir,
3839
)
3940

@@ -44,6 +45,7 @@ def generate(config: codegen.CodegenConfig, output_dir: Path = None) -> Path:
4445
python_util=python_util,
4546
camera_cal_class_names=cam_package_codegen.camera_cal_class_names(),
4647
),
48+
config=config.render_template_config,
4749
template_dir=template_dir,
4850
)
4951

@@ -52,6 +54,7 @@ def generate(config: codegen.CodegenConfig, output_dir: Path = None) -> Path:
5254
template_path=f"{filename}.jinja",
5355
output_path=package_dir / filename,
5456
data={},
57+
config=config.render_template_config,
5558
template_dir=template_dir,
5659
)
5760
else:

0 commit comments

Comments
 (0)