Skip to content

Commit eab63ed

Browse files
authored
Jittable transforms (#1575)
* [WIP] jittable transforms
* add licence to new test file
* turn BijectorConstraint into pytree
* test flattening/unflattening of parametrized constraints
* cosmetic edits
* fix typo
* implement tree_flatten/unflatten for transforms
* attempt to avoid confusing black
* add (un)flattening meths for BijectorTransform
* fixup! implement tree_flatten/unflatten for transforms
* test vmapping over transforms/constraints
* Make constraints `__eq__` checks robust to arbitrary inputs
* make transforms equality check robust to arbitrary inputs
* test constraints and transforms equality checks
1 parent e230805 commit eab63ed

File tree

6 files changed

+656
-23
lines changed

6 files changed

+656
-23
lines changed

numpyro/contrib/tfp/distributions.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,13 @@ def __call__(self, x):
6666
def codomain(self):
6767
return _get_codomain(self.bijector)
6868

69+
def tree_flatten(self):
70+
return self.bijector, ()
71+
72+
@classmethod
73+
def tree_unflatten(cls, _, bijector):
74+
return cls(bijector)
75+
6976

7077
class BijectorTransform(Transform):
7178
"""
@@ -106,6 +113,13 @@ def inverse_shape(self, shape):
106113
batch_shape = shape[: len(shape) - len(out_event_shape)]
107114
return batch_shape + in_shape
108115

116+
def tree_flatten(self):
117+
return self.bijector, ()
118+
119+
@classmethod
120+
def tree_unflatten(cls, _, bijector):
121+
return cls(bijector)
122+
109123

110124
@biject_to.register(BijectorConstraint)
111125
def _transform_to_bijector_constraint(constraint):

numpyro/distributions/constraints.py

Lines changed: 114 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@
6262
import numpy as np
6363

6464
import jax.numpy
65+
import jax.numpy as jnp
66+
from jax.tree_util import register_pytree_node
6567

6668

6769
class Constraint(object):
@@ -75,6 +77,10 @@ class Constraint(object):
7577
is_discrete = False
7678
event_dim = 0
7779

80+
def __init_subclass__(cls, **kwargs):
81+
super().__init_subclass__(**kwargs)
82+
register_pytree_node(cls, cls.tree_flatten, cls.tree_unflatten)
83+
7884
def __call__(self, x):
7985
raise NotImplementedError
8086

@@ -94,8 +100,24 @@ def feasible_like(self, prototype):
94100
"""
95101
raise NotImplementedError
96102

103+
@classmethod
104+
def tree_unflatten(cls, aux_data, params):
105+
params_keys, aux_data = aux_data
106+
self = cls.__new__(cls)
107+
for k, v in zip(params_keys, params):
108+
setattr(self, k, v)
109+
110+
for k, v in aux_data.items():
111+
setattr(self, k, v)
112+
return self
113+
114+
115+
class ParameterFreeConstraint(Constraint):
116+
def tree_flatten(self):
117+
return (), ((), dict())
118+
97119

98-
class _SingletonConstraint(Constraint):
120+
class _SingletonConstraint(ParameterFreeConstraint):
99121
"""
100122
A constraint type which has only one canonical instance, like constraints.real,
101123
and unlike constraints.interval.
@@ -202,8 +224,23 @@ def __call__(self, x=None, *, is_discrete=NotImplemented, event_dim=NotImplement
202224
event_dim = self._event_dim
203225
return _Dependent(is_discrete=is_discrete, event_dim=event_dim)
204226

227+
def __eq__(self, other):
228+
return (
229+
type(self) is type(other)
230+
and self._is_discrete == other._is_discrete
231+
and self._event_dim == other._event_dim
232+
)
233+
234+
def tree_flatten(self):
235+
return (), (
236+
(),
237+
dict(_is_discrete=self._is_discrete, _event_dim=self._event_dim),
238+
)
239+
205240

206241
class dependent_property(property, _Dependent):
242+
# XXX: this should not need to be pytree-able since it simply wraps a method
243+
# and thus is automatically present once the method's object is created
207244
def __init__(
208245
self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented
209246
):
@@ -243,8 +280,16 @@ def __repr__(self):
243280
def feasible_like(self, prototype):
244281
return jax.numpy.broadcast_to(self.lower_bound + 1, jax.numpy.shape(prototype))
245282

283+
def tree_flatten(self):
284+
return (self.lower_bound,), (("lower_bound",), dict())
285+
286+
def __eq__(self, other):
287+
if not isinstance(other, _GreaterThan):
288+
return False
289+
return jnp.array_equal(self.lower_bound, other.lower_bound)
246290

247-
class _Positive(_GreaterThan, _SingletonConstraint):
291+
292+
class _Positive(_SingletonConstraint, _GreaterThan):
248293
def __init__(self):
249294
super().__init__(0.0)
250295

@@ -301,6 +346,20 @@ def __repr__(self):
301346
def feasible_like(self, prototype):
302347
return self.base_constraint.feasible_like(prototype)
303348

349+
def tree_flatten(self):
350+
return (self.base_constraint,), (
351+
("base_constraint",),
352+
{"reinterpreted_batch_ndims": self.reinterpreted_batch_ndims},
353+
)
354+
355+
def __eq__(self, other):
356+
if not isinstance(other, _IndependentConstraint):
357+
return False
358+
359+
return (self.base_constraint == other.base_constraint) & (
360+
self.reinterpreted_batch_ndims == other.reinterpreted_batch_ndims
361+
)
362+
304363

305364
class _RealVector(_IndependentConstraint, _SingletonConstraint):
306365
def __init__(self):
@@ -327,6 +386,14 @@ def __repr__(self):
327386
def feasible_like(self, prototype):
328387
return jax.numpy.broadcast_to(self.upper_bound - 1, jax.numpy.shape(prototype))
329388

389+
def tree_flatten(self):
390+
return (self.upper_bound,), (("upper_bound",), dict())
391+
392+
def __eq__(self, other):
393+
if not isinstance(other, _LessThan):
394+
return False
395+
return jnp.array_equal(self.upper_bound, other.upper_bound)
396+
330397

331398
class _IntegerInterval(Constraint):
332399
is_discrete = True
@@ -348,6 +415,20 @@ def __repr__(self):
348415
def feasible_like(self, prototype):
349416
return jax.numpy.broadcast_to(self.lower_bound, jax.numpy.shape(prototype))
350417

418+
def tree_flatten(self):
419+
return (self.lower_bound, self.upper_bound), (
420+
("lower_bound", "upper_bound"),
421+
dict(),
422+
)
423+
424+
def __eq__(self, other):
425+
if not isinstance(other, _IntegerInterval):
426+
return False
427+
428+
return jnp.array_equal(self.lower_bound, other.lower_bound) & jnp.array_equal(
429+
self.upper_bound, other.upper_bound
430+
)
431+
351432

352433
class _IntegerGreaterThan(Constraint):
353434
is_discrete = True
@@ -366,13 +447,21 @@ def __repr__(self):
366447
def feasible_like(self, prototype):
367448
return jax.numpy.broadcast_to(self.lower_bound, jax.numpy.shape(prototype))
368449

450+
def tree_flatten(self):
451+
return (self.lower_bound,), (("lower_bound",), dict())
369452

370-
class _IntegerPositive(_IntegerGreaterThan, _SingletonConstraint):
453+
def __eq__(self, other):
454+
if not isinstance(other, _IntegerGreaterThan):
455+
return False
456+
return jnp.array_equal(self.lower_bound, other.lower_bound)
457+
458+
459+
class _IntegerPositive(_SingletonConstraint, _IntegerGreaterThan):
371460
def __init__(self):
372461
super().__init__(1)
373462

374463

375-
class _IntegerNonnegative(_IntegerGreaterThan, _SingletonConstraint):
464+
class _IntegerNonnegative(_SingletonConstraint, _IntegerGreaterThan):
376465
def __init__(self):
377466
super().__init__(0)
378467

@@ -398,19 +487,25 @@ def feasible_like(self, prototype):
398487
)
399488

400489
def __eq__(self, other):
401-
return (
402-
isinstance(other, _Interval)
403-
and self.lower_bound == other.lower_bound
404-
and self.upper_bound == other.upper_bound
490+
if not isinstance(other, _Interval):
491+
return False
492+
return jnp.array_equal(self.lower_bound, other.lower_bound) & jnp.array_equal(
493+
self.upper_bound, other.upper_bound
405494
)
406495

496+
def tree_flatten(self):
497+
return (self.lower_bound, self.upper_bound), (
498+
("lower_bound", "upper_bound"),
499+
dict(),
500+
)
407501

408-
class _Circular(_Interval, _SingletonConstraint):
502+
503+
class _Circular(_SingletonConstraint, _Interval):
409504
def __init__(self):
410505
super().__init__(-math.pi, math.pi)
411506

412507

413-
class _UnitInterval(_Interval, _SingletonConstraint):
508+
class _UnitInterval(_SingletonConstraint, _Interval):
414509
def __init__(self):
415510
super().__init__(0.0, 1.0)
416511

@@ -462,6 +557,14 @@ def feasible_like(self, prototype):
462557
value = jax.numpy.pad(jax.numpy.expand_dims(self.upper_bound, -1), pad_width)
463558
return jax.numpy.broadcast_to(value, prototype.shape)
464559

560+
def tree_flatten(self):
561+
return (self.upper_bound,), (("upper_bound",), dict())
562+
563+
def __eq__(self, other):
564+
if not isinstance(other, _Multinomial):
565+
return False
566+
return jnp.array_equal(self.upper_bound, other.upper_bound)
567+
465568

466569
class _L1Ball(_SingletonConstraint):
467570
"""
@@ -546,7 +649,7 @@ def feasible_like(self, prototype):
546649
return jax.numpy.full_like(prototype, 1 / prototype.shape[-1])
547650

548651

549-
class _SoftplusPositive(_GreaterThan, _SingletonConstraint):
652+
class _SoftplusPositive(_SingletonConstraint, _GreaterThan):
550653
def __init__(self):
551654
super().__init__(lower_bound=0.0)
552655

numpyro/distributions/flows.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,21 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
9393
log_scale = intermediates
9494
return log_scale.sum(-1)
9595

96+
def tree_flatten(self):
97+
return (self.log_scale_min_clip, self.log_scale_max_clip), (
98+
("log_scale_min_clip", "log_scale_max_clip"),
99+
{"arn": self.arn},
100+
)
101+
102+
def __eq__(self, other):
103+
if not isinstance(other, InverseAutoregressiveTransform):
104+
return False
105+
return (
106+
(self.arn is other.arn)
107+
& jnp.array_equal(self.log_scale_min_clip, other.log_scale_min_clip)
108+
& jnp.array_equal(self.log_scale_max_clip, other.log_scale_max_clip)
109+
)
110+
96111

97112
class BlockNeuralAutoregressiveTransform(Transform):
98113
"""
@@ -139,3 +154,12 @@ def log_abs_det_jacobian(self, x, y, intermediates=None):
139154
else:
140155
logdet = intermediates
141156
return logdet.sum(-1)
157+
158+
def tree_flatten(self):
159+
return (), ((), {"bn_arn": self.bn_arn})
160+
161+
def __eq__(self, other):
162+
return (
163+
isinstance(other, BlockNeuralAutoregressiveTransform)
164+
and self.bn_arn is other.bn_arn
165+
)

0 commit comments

Comments
 (0)