62
62
import numpy as np
63
63
64
64
import jax .numpy
65
+ import jax .numpy as jnp
66
+ from jax .tree_util import register_pytree_node
65
67
66
68
67
69
class Constraint (object ):
@@ -75,6 +77,10 @@ class Constraint(object):
75
77
is_discrete = False
76
78
event_dim = 0
77
79
80
+ def __init_subclass__ (cls , ** kwargs ):
81
+ super ().__init_subclass__ (** kwargs )
82
+ register_pytree_node (cls , cls .tree_flatten , cls .tree_unflatten )
83
+
78
84
def __call__ (self , x ):
79
85
raise NotImplementedError
80
86
@@ -94,8 +100,24 @@ def feasible_like(self, prototype):
94
100
"""
95
101
raise NotImplementedError
96
102
103
+ @classmethod
104
+ def tree_unflatten (cls , aux_data , params ):
105
+ params_keys , aux_data = aux_data
106
+ self = cls .__new__ (cls )
107
+ for k , v in zip (params_keys , params ):
108
+ setattr (self , k , v )
109
+
110
+ for k , v in aux_data .items ():
111
+ setattr (self , k , v )
112
+ return self
113
+
114
+
115
+ class ParameterFreeConstraint (Constraint ):
116
+ def tree_flatten (self ):
117
+ return (), ((), dict ())
118
+
97
119
98
- class _SingletonConstraint (Constraint ):
120
+ class _SingletonConstraint (ParameterFreeConstraint ):
99
121
"""
100
122
A constraint type which has only one canonical instance, like constraints.real,
101
123
and unlike constraints.interval.
@@ -202,8 +224,23 @@ def __call__(self, x=None, *, is_discrete=NotImplemented, event_dim=NotImplement
202
224
event_dim = self ._event_dim
203
225
return _Dependent (is_discrete = is_discrete , event_dim = event_dim )
204
226
227
+ def __eq__ (self , other ):
228
+ return (
229
+ type (self ) is type (other )
230
+ and self ._is_discrete == other ._is_discrete
231
+ and self ._event_dim == other ._event_dim
232
+ )
233
+
234
+ def tree_flatten (self ):
235
+ return (), (
236
+ (),
237
+ dict (_is_discrete = self ._is_discrete , _event_dim = self ._event_dim ),
238
+ )
239
+
205
240
206
241
class dependent_property (property , _Dependent ):
242
+ # XXX: this should not need to be pytree-able since it simply wraps a method
243
+ # and thus is automatically present once the method's object is created
207
244
def __init__ (
208
245
self , fn = None , * , is_discrete = NotImplemented , event_dim = NotImplemented
209
246
):
@@ -243,8 +280,16 @@ def __repr__(self):
243
280
def feasible_like (self , prototype ):
244
281
return jax .numpy .broadcast_to (self .lower_bound + 1 , jax .numpy .shape (prototype ))
245
282
283
+ def tree_flatten (self ):
284
+ return (self .lower_bound ,), (("lower_bound" ,), dict ())
285
+
286
+ def __eq__ (self , other ):
287
+ if not isinstance (other , _GreaterThan ):
288
+ return False
289
+ return jnp .array_equal (self .lower_bound , other .lower_bound )
246
290
247
- class _Positive (_GreaterThan , _SingletonConstraint ):
291
+
292
+ class _Positive (_SingletonConstraint , _GreaterThan ):
248
293
def __init__ (self ):
249
294
super ().__init__ (0.0 )
250
295
@@ -301,6 +346,20 @@ def __repr__(self):
301
346
def feasible_like (self , prototype ):
302
347
return self .base_constraint .feasible_like (prototype )
303
348
349
+ def tree_flatten (self ):
350
+ return (self .base_constraint ,), (
351
+ ("base_constraint" ,),
352
+ {"reinterpreted_batch_ndims" : self .reinterpreted_batch_ndims },
353
+ )
354
+
355
+ def __eq__ (self , other ):
356
+ if not isinstance (other , _IndependentConstraint ):
357
+ return False
358
+
359
+ return (self .base_constraint == other .base_constraint ) & (
360
+ self .reinterpreted_batch_ndims == other .reinterpreted_batch_ndims
361
+ )
362
+
304
363
305
364
class _RealVector (_IndependentConstraint , _SingletonConstraint ):
306
365
def __init__ (self ):
@@ -327,6 +386,14 @@ def __repr__(self):
327
386
def feasible_like (self , prototype ):
328
387
return jax .numpy .broadcast_to (self .upper_bound - 1 , jax .numpy .shape (prototype ))
329
388
389
+ def tree_flatten (self ):
390
+ return (self .upper_bound ,), (("upper_bound" ,), dict ())
391
+
392
+ def __eq__ (self , other ):
393
+ if not isinstance (other , _LessThan ):
394
+ return False
395
+ return jnp .array_equal (self .upper_bound , other .upper_bound )
396
+
330
397
331
398
class _IntegerInterval (Constraint ):
332
399
is_discrete = True
@@ -348,6 +415,20 @@ def __repr__(self):
348
415
def feasible_like (self , prototype ):
349
416
return jax .numpy .broadcast_to (self .lower_bound , jax .numpy .shape (prototype ))
350
417
418
+ def tree_flatten (self ):
419
+ return (self .lower_bound , self .upper_bound ), (
420
+ ("lower_bound" , "upper_bound" ),
421
+ dict (),
422
+ )
423
+
424
+ def __eq__ (self , other ):
425
+ if not isinstance (other , _IntegerInterval ):
426
+ return False
427
+
428
+ return jnp .array_equal (self .lower_bound , other .lower_bound ) & jnp .array_equal (
429
+ self .upper_bound , other .upper_bound
430
+ )
431
+
351
432
352
433
class _IntegerGreaterThan (Constraint ):
353
434
is_discrete = True
@@ -366,13 +447,21 @@ def __repr__(self):
366
447
def feasible_like (self , prototype ):
367
448
return jax .numpy .broadcast_to (self .lower_bound , jax .numpy .shape (prototype ))
368
449
450
+ def tree_flatten (self ):
451
+ return (self .lower_bound ,), (("lower_bound" ,), dict ())
369
452
370
- class _IntegerPositive (_IntegerGreaterThan , _SingletonConstraint ):
453
+ def __eq__ (self , other ):
454
+ if not isinstance (other , _IntegerGreaterThan ):
455
+ return False
456
+ return jnp .array_equal (self .lower_bound , other .lower_bound )
457
+
458
+
459
+ class _IntegerPositive (_SingletonConstraint , _IntegerGreaterThan ):
371
460
def __init__ (self ):
372
461
super ().__init__ (1 )
373
462
374
463
375
- class _IntegerNonnegative (_IntegerGreaterThan , _SingletonConstraint ):
464
+ class _IntegerNonnegative (_SingletonConstraint , _IntegerGreaterThan ):
376
465
def __init__ (self ):
377
466
super ().__init__ (0 )
378
467
@@ -398,19 +487,25 @@ def feasible_like(self, prototype):
398
487
)
399
488
400
489
def __eq__ (self , other ):
401
- return (
402
- isinstance ( other , _Interval )
403
- and self .lower_bound == other .lower_bound
404
- and self .upper_bound == other .upper_bound
490
+ if not isinstance ( other , _Interval ):
491
+ return False
492
+ return jnp . array_equal ( self .lower_bound , other .lower_bound ) & jnp . array_equal (
493
+ self .upper_bound , other .upper_bound
405
494
)
406
495
496
+ def tree_flatten (self ):
497
+ return (self .lower_bound , self .upper_bound ), (
498
+ ("lower_bound" , "upper_bound" ),
499
+ dict (),
500
+ )
407
501
408
- class _Circular (_Interval , _SingletonConstraint ):
502
+
503
+ class _Circular (_SingletonConstraint , _Interval ):
409
504
def __init__ (self ):
410
505
super ().__init__ (- math .pi , math .pi )
411
506
412
507
413
- class _UnitInterval (_Interval , _SingletonConstraint ):
508
+ class _UnitInterval (_SingletonConstraint , _Interval ):
414
509
def __init__ (self ):
415
510
super ().__init__ (0.0 , 1.0 )
416
511
@@ -462,6 +557,14 @@ def feasible_like(self, prototype):
462
557
value = jax .numpy .pad (jax .numpy .expand_dims (self .upper_bound , - 1 ), pad_width )
463
558
return jax .numpy .broadcast_to (value , prototype .shape )
464
559
560
+ def tree_flatten (self ):
561
+ return (self .upper_bound ,), (("upper_bound" ,), dict ())
562
+
563
+ def __eq__ (self , other ):
564
+ if not isinstance (other , _Multinomial ):
565
+ return False
566
+ return jnp .array_equal (self .upper_bound , other .upper_bound )
567
+
465
568
466
569
class _L1Ball (_SingletonConstraint ):
467
570
"""
@@ -546,7 +649,7 @@ def feasible_like(self, prototype):
546
649
return jax .numpy .full_like (prototype , 1 / prototype .shape [- 1 ])
547
650
548
651
549
- class _SoftplusPositive (_GreaterThan , _SingletonConstraint ):
652
+ class _SoftplusPositive (_SingletonConstraint , _GreaterThan ):
550
653
def __init__ (self ):
551
654
super ().__init__ (lower_bound = 0.0 )
552
655
0 commit comments