Skip to content

Commit a5341d8

Browse files
committed
[SymForce] PyTorch backend
Just functions, and just scalar/matrix inputs/outputs for now. Reviewers: hayk, bradley, nathan, chao, peter, john-m, cedric. Topic: sym-torch. Relative: sf-preamble. GitOrigin-RevId: 18c8c39b8b1c5f0ccea4ae5de534ab2a3eaeb9cb
1 parent 1c1268a commit a5341d8

File tree

9 files changed

+494
-3
lines changed

9 files changed

+494
-3
lines changed
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
***THIS MODULE IS EXPERIMENTAL***
2+
3+
Backend for PyTorch. This generates Python functions that just call PyTorch ops, i.e. each SymForce op in your expression tree becomes a PyTorch op.
4+
5+
It's possible we could do significantly better than this by generating custom PyTorch ops instead.
6+
7+
This currently only supports vector inputs and outputs; we do not have geo or cam types for PyTorch yet.

symforce/codegen/backends/pytorch/__init__.py

Whitespace-only changes.
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
# ----------------------------------------------------------------------------
2+
# SymForce - Copyright 2022, Skydio, Inc.
3+
# This source code is under the Apache 2.0 license found in the LICENSE file.
4+
# ----------------------------------------------------------------------------
5+
from __future__ import annotations
6+
7+
import sympy
8+
from sympy.printing.codeprinter import CodePrinter
9+
from sympy.printing.pycode import PythonCodePrinter
10+
11+
from symforce import typing as T
12+
13+
_known_functions_torch = {
14+
"Abs": "abs",
15+
"acos": "acos",
16+
"acosh": "acosh",
17+
"asin": "asin",
18+
"asinh": "asinh",
19+
"atan": "atan",
20+
"atan2": "atan2",
21+
"atanh": "atanh",
22+
"ceiling": "ceil",
23+
"cos": "cos",
24+
"cosh": "cosh",
25+
"erf": "erf",
26+
"erfc": "erfc",
27+
"exp": "exp",
28+
"expm1": "expm1",
29+
"floor": "floor",
30+
"hypot": "hypot",
31+
"loggamma": "lgamma",
32+
"log": "log",
33+
"ln": "log",
34+
"log10": "log10",
35+
"log1p": "log1p",
36+
"log2": "log2",
37+
"sin": "sin",
38+
"sinh": "sinh",
39+
"Sqrt": "sqrt",
40+
"tan": "tan",
41+
"tanh": "tanh",
42+
}
43+
44+
_known_constants_math = {
45+
"Exp1": "e",
46+
"Pi": "pi",
47+
"E": "e",
48+
"Infinity": "inf",
49+
"NaN": "nan",
50+
"ComplexInfinity": "nan",
51+
}
52+
53+
54+
def _print_known_const(self: PyTorchCodePrinter, expr: sympy.Expr) -> str:
55+
return f"torch.tensor(math.{_known_constants_math[expr.__class__.__name__]}, **tensor_kwargs)"
56+
57+
58+
def _print_known_func(self: PyTorchCodePrinter, expr: sympy.Expr) -> str:
59+
name = _known_functions_torch[expr.__class__.__name__]
60+
return f"torch.{name}({', '.join(map(self._print, expr.args))})" # pylint: disable=protected-access
61+
62+
63+
class PyTorchCodePrinter(CodePrinter):
64+
"""
65+
Symforce customized code printer for PyTorch. Modifies the Sympy printing
66+
behavior for codegen compatibility and efficiency.
67+
68+
This is more different from PythonCodePrinter than it is similar, so we go mostly from scratch
69+
and call some methods from that printer where desired.
70+
"""
71+
72+
def _format_code(self, lines: T.List[str]) -> T.List[str]:
73+
return lines
74+
75+
def _print_Mod(self, expr: sympy.Mod) -> str:
76+
return f"torch.remainder({self._print(expr.args[0])}, {self._print(expr.args[1])})"
77+
78+
def _print_sign(self, expr: sympy.sign) -> str:
79+
return f"torch.sign({self._print(expr.args[0])})"
80+
81+
def _print_Pow(
82+
self, expr: sympy.Pow, rational: bool = False
83+
) -> str: # pylint: disable=unused-argument
84+
# TODO(aaron): Optimize this?
85+
return f"torch.pow({self._print(expr.base)}, {self._print(expr.exp)})"
86+
87+
def _print_Rational(self, expr: sympy.Rational) -> str:
88+
# This is py3-only, need decimal points if we want py2
89+
return f"torch.tensor({expr.p}/{expr.q}, **tensor_kwargs)"
90+
91+
def _print_frac(self, expr: sympy.frac) -> str:
92+
return self._print_Mod(sympy.Mod(expr.args[0], 1))
93+
94+
def _print_Integer(self, expr: sympy.Integer) -> str:
95+
"""
96+
Customizations:
97+
* Cast all integers to Tensor
98+
"""
99+
return f"torch.tensor({expr.p}, **tensor_kwargs)"
100+
101+
def _print_NumberSymbol(self, expr: sympy.Expr) -> str:
102+
"""
103+
Customizations:
104+
* Cast all NumberSymbols to Tensor
105+
"""
106+
return f"torch.tensor({super()._print_NumberSymbol(expr)}, **tensor_kwargs)"
107+
108+
def _print_Zero(self, expr: sympy.Expr) -> str:
109+
"""
110+
Customizations:
111+
* Cast Zero to Tensor
112+
"""
113+
return "torch.tensor(0, **tensor_kwargs)"
114+
115+
def _print_Symbol(self, expr: sympy.Symbol) -> str:
116+
name = super()._print_Symbol(expr)
117+
118+
if name in PythonCodePrinter.reserved_words:
119+
raise ValueError(
120+
f'This expression includes the symbol "{name}" which is a reserved keyword in Python.'
121+
)
122+
123+
return name
124+
125+
def _print_Max(self, expr: sympy.Max) -> str:
126+
if len(expr.args) == 1:
127+
return self._print(expr.args[0])
128+
else:
129+
from sympy.functions.elementary.miscellaneous import Max
130+
131+
return f"torch.maximum({self._print(expr.args[0])}, {self._print(Max(*expr.args[1:]))})"
132+
133+
def _print_Min(self, expr: sympy.Min) -> str:
134+
if len(expr.args) == 1:
135+
return self._print(expr.args[0])
136+
else:
137+
from sympy.functions.elementary.miscellaneous import Min
138+
139+
return f"torch.minimum({self._print(expr.args[0])}, {self._print(Min(*expr.args[1:]))})"
140+
141+
# NOTE(brad): We type ignore the signature because mypy complains that it
142+
# does not match that of the sympy base class CodePrinter. This is because the base class
143+
# defines _print_Heaviside with: _print_Heaviside = None (see
144+
# https://github.com/sympy/sympy/blob/95f0228c033d27731f8707cdbb5bb672e500847d/sympy/printing/codeprinter.py#L446
145+
# ).
146+
# Despite this, our signature here matches the signatures of the sympy defined subclasses
147+
# of CodePrinter. I don't know of any other way to resolve this issue other than to
148+
# to type ignore.
149+
def _print_Heaviside(self, expr: "sympy.Heaviside") -> str: # type: ignore[override]
150+
return f"torch.heaviside({self._print(expr)}, values=torch.tensor(1.0, **tensor_kwargs))"
151+
152+
153+
for k in _known_functions_torch:
154+
setattr(PyTorchCodePrinter, f"_print_{k}", _print_known_func)
155+
156+
for k in _known_constants_math:
157+
setattr(PyTorchCodePrinter, f"_print_{k}", _print_known_const)
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# ----------------------------------------------------------------------------
2+
# SymForce - Copyright 2022, Skydio, Inc.
3+
# This source code is under the Apache 2.0 license found in the LICENSE file.
4+
# ----------------------------------------------------------------------------
5+
from dataclasses import dataclass
6+
from pathlib import Path
7+
8+
from sympy.printing.codeprinter import CodePrinter
9+
10+
from symforce import typing as T
11+
from symforce.codegen.backends.pytorch import pytorch_code_printer
12+
from symforce.codegen.codegen_config import CodegenConfig
13+
14+
CURRENT_DIR = Path(__file__).parent
15+
16+
17+
@dataclass
18+
class PyTorchConfig(CodegenConfig):
19+
"""
20+
Code generation config for the PyTorch backend.
21+
22+
Args:
23+
doc_comment_line_prefix: Prefix applied to each line in a docstring
24+
line_length: Maximum allowed line length in docstrings; used for formatting docstrings.
25+
use_eigen_types: Use eigen_lcm types for vectors instead of lists
26+
autoformat: Run a code formatter on the generated code
27+
custom_preamble: An optional string to be prepended on the front of the rendered template
28+
cse_optimizations: Optimizations argument to pass to sf.cse
29+
zero_epsilon_behavior: What should codegen do if a default epsilon is not set?
30+
"""
31+
32+
doc_comment_line_prefix: str = ""
33+
line_length: int = 100
34+
use_eigen_types: bool = False
35+
36+
@classmethod
37+
def backend_name(cls) -> str:
38+
return "pytorch"
39+
40+
@classmethod
41+
def template_dir(cls) -> Path:
42+
return CURRENT_DIR / "templates"
43+
44+
def templates_to_render(self, generated_file_name: str) -> T.List[T.Tuple[str, str]]:
45+
return [
46+
("function/FUNCTION.py.jinja", f"{generated_file_name}.py"),
47+
("function/__init__.py.jinja", "__init__.py"),
48+
]
49+
50+
def printer(self) -> CodePrinter:
51+
return pytorch_code_printer.PyTorchCodePrinter()
52+
53+
def format_matrix_accessor(self, key: str, i: int, j: int, *, shape: T.Tuple[int, int]) -> str:
54+
PyTorchConfig._assert_indices_in_bounds(i, j, shape)
55+
if (shape[0] == 1) ^ (shape[1] == 1):
56+
return f"{key}[..., {max(i, j)}]"
57+
elif shape[0] == 1 and shape[1] == 1:
58+
return key
59+
else:
60+
return f"{key}[..., {i}, {j}]"
61+
62+
@staticmethod
63+
def format_eigen_lcm_accessor(key: str, i: int) -> str:
64+
"""
65+
Format accessor for eigen_lcm types.
66+
"""
67+
raise NotImplementedError("Can't pass eigen_lcm types to PyTorch functions")
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
{# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
#
# Template for a single generated PyTorch function module: emits the imports,
# the TensorKwargs TypedDict referenced by generated torch.tensor calls, then
# the function declaration, docstring, and expression body via util.jinja.
# ---------------------------------------------------------------------------- #}
{%- import "../util/util.jinja" as util with context -%}

# pylint: disable=too-many-locals,too-many-lines,too-many-statements,unused-argument

import math  # pylint: disable=unused-import
import typing as T

import torch


class TensorKwargs(T.TypedDict):
    """
    TypedDict representing args that will be passed to any torch.tensor calls
    """

    # Device on which constructed tensors are placed
    device: torch.device
    # Element dtype of constructed tensors
    dtype: torch.dtype


{{ util.function_declaration(spec) }}
{% if spec.docstring %}
{{ util.print_docstring(spec.docstring) | indent(4) }}
{% endif %}

{{ util.expr_code(spec) }}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
{# ----------------------------------------------------------------------------
# SymForce - Copyright 2022, Skydio, Inc.
# This source code is under the Apache 2.0 license found in the LICENSE file.
#
# Template for the generated package __init__: re-exports the generated
# function so callers can import it directly from the package.
# ---------------------------------------------------------------------------- #}
from .{{ spec.name }} import {{ spec.name }}

0 commit comments

Comments
 (0)