Skip to content

Commit f9c756c

Browse files
authored
AsymmetricLaplace distributions (#1332)
* AsymmetricLaplace distributions, related to #1319 * Adding ALD and ALDQ distributions to be importable from numpyro.distributions * fixes typo in __all__ for ALDQ * Updating tests, docs, and converting ALDQ to ALD under the hood * fixing qscale typo in reparameterized_params of ALDQ * Rewritten cdf, icdf, and fixing batching dims, updating tests * Reordering args, removing gamma_params from testing
1 parent 38b0d48 commit f9c756c

File tree

4 files changed

+172
-2
lines changed

4 files changed

+172
-2
lines changed

docs/source/distributions.rst

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,22 @@ Unit
8080
Continuous Distributions
8181
------------------------
8282

83+
AsymmetricLaplace
84+
^^^^^^^^^^^^^^^^^
85+
.. autoclass:: numpyro.distributions.continuous.AsymmetricLaplace
86+
:members:
87+
:undoc-members:
88+
:show-inheritance:
89+
:member-order: bysource
90+
91+
AsymmetricLaplaceQuantile
92+
^^^^^^^^^^^^^^^^^^^^^^^^^
93+
.. autoclass:: numpyro.distributions.continuous.AsymmetricLaplaceQuantile
94+
:members:
95+
:undoc-members:
96+
:show-inheritance:
97+
:member-order: bysource
98+
8399
Beta
84100
^^^^
85101
.. autoclass:: numpyro.distributions.continuous.Beta
@@ -809,7 +825,7 @@ ExpTransform
809825
:undoc-members:
810826
:show-inheritance:
811827
:member-order: bysource
812-
828+
813829
IdentityTransform
814830
^^^^^^^^^^^^^^^^^
815831
.. autoclass:: numpyro.distributions.transforms.IdentityTransform

numpyro/distributions/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
)
1313
from numpyro.distributions.continuous import (
1414
LKJ,
15+
AsymmetricLaplace,
16+
AsymmetricLaplaceQuantile,
1517
Beta,
1618
BetaProportion,
1719
Cauchy,
@@ -100,6 +102,8 @@
100102
"constraints",
101103
"kl_divergence",
102104
"transforms",
105+
"AsymmetricLaplace",
106+
"AsymmetricLaplaceQuantile",
103107
"Bernoulli",
104108
"BernoulliLogits",
105109
"BernoulliProbs",

numpyro/distributions/continuous.py

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
2626
# POSSIBILITY OF SUCH DAMAGE.
2727

28-
2928
from jax import lax
3029
import jax.nn as nn
3130
import jax.numpy as jnp
@@ -67,6 +66,82 @@
6766
)
6867

6968

69+
class AsymmetricLaplace(Distribution):
70+
arg_constraints = {
71+
"loc": constraints.real,
72+
"scale": constraints.positive,
73+
"asymmetry": constraints.positive,
74+
}
75+
reparametrized_params = ["loc", "scale", "asymmetry"]
76+
support = constraints.real
77+
78+
def __init__(self, loc=0.0, scale=1.0, asymmetry=1.0, validate_args=None):
79+
batch_shape = lax.broadcast_shapes(
80+
jnp.shape(loc), jnp.shape(scale), jnp.shape(asymmetry)
81+
)
82+
self.loc, self.scale, self.asymmetry = promote_shapes(
83+
loc, scale, asymmetry, shape=batch_shape
84+
)
85+
super(AsymmetricLaplace, self).__init__(
86+
batch_shape=batch_shape, validate_args=validate_args
87+
)
88+
89+
@lazy_property
90+
def left_scale(self):
91+
return self.scale * self.asymmetry
92+
93+
@lazy_property
94+
def right_scale(self):
95+
return self.scale / self.asymmetry
96+
97+
def log_prob(self, value):
98+
if self._validate_args:
99+
self._validate_sample(value)
100+
z = value - self.loc
101+
z = -jnp.abs(z) / jnp.where(z < 0, self.left_scale, self.right_scale)
102+
return z - jnp.log(self.left_scale + self.right_scale)
103+
104+
def sample(self, key, sample_shape=()):
105+
assert is_prng_key(key)
106+
shape = (2,) + sample_shape + self.batch_shape + self.event_shape
107+
u, v = random.exponential(key, shape=shape)
108+
return self.loc - self.left_scale * u + self.right_scale * v
109+
110+
@property
111+
def mean(self):
112+
total_scale = self.left_scale + self.right_scale
113+
mean = self.loc + (self.right_scale**2 - self.left_scale**2) / total_scale
114+
return jnp.broadcast_to(mean, self.batch_shape)
115+
116+
@property
117+
def variance(self):
118+
left = self.left_scale
119+
right = self.right_scale
120+
total = left + right
121+
p = left / total
122+
q = right / total
123+
variance = p * left**2 + q * right**2 + p * q * total**2
124+
return jnp.broadcast_to(variance, self.batch_shape)
125+
126+
def cdf(self, value):
127+
z = value - self.loc
128+
k = self.asymmetry
129+
return jnp.where(
130+
z >= 0,
131+
1 - (1 / (1 + k**2)) * jnp.exp(-jnp.abs(z) / self.right_scale),
132+
k**2 / (1 + k**2) * jnp.exp(-jnp.abs(z) / self.left_scale),
133+
)
134+
135+
def icdf(self, value):
136+
k = self.asymmetry
137+
temp = k**2 / (1 + k**2)
138+
return jnp.where(
139+
value <= temp,
140+
self.loc + self.left_scale * jnp.log(value / temp),
141+
self.loc - self.right_scale * jnp.log((1 + k**2) * (1 - value)),
142+
)
143+
144+
70145
class Beta(Distribution):
71146
arg_constraints = {
72147
"concentration1": constraints.positive,
@@ -1777,3 +1852,64 @@ def __init__(self, mean, concentration, validate_args=None):
17771852
(1.0 - mean) * concentration,
17781853
validate_args=validate_args,
17791854
)
1855+
1856+
1857+
class AsymmetricLaplaceQuantile(Distribution):
1858+
"""An alternative parameterization of AsymmetricLaplace commonly applied in
1859+
Bayesian quantile regression.
1860+
1861+
Instead of the `asymmetry` parameter employed by AsymmetricLaplace, to
1862+
define the balance between left- versus right-hand sides of the
1863+
distribution, this class utilizes a `quantile` parameter, which describes
1864+
the proportion of probability density that falls to the left-hand side of
1865+
the distribution.
1866+
1867+
The `scale` parameter is also interpreted slightly differently than in
1868+
AsymmetricLaplce. When `loc=0` and `scale=1`, AsymmetricLaplace(0,1,1)
1869+
is equivalent to Laplace(0,1), while AsymmetricLaplaceQuantile(0,1,0.5) is
1870+
equivalent to Laplace(0,2).
1871+
"""
1872+
1873+
arg_constraints = {
1874+
"loc": constraints.real,
1875+
"scale": constraints.positive,
1876+
"quantile": constraints.open_interval(0.0, 1.0),
1877+
}
1878+
reparametrized_params = ["loc", "scale", "quantile"]
1879+
support = constraints.real
1880+
1881+
def __init__(self, loc=0.0, scale=1.0, quantile=0.5, validate_args=None):
1882+
batch_shape = lax.broadcast_shapes(
1883+
jnp.shape(loc), jnp.shape(scale), jnp.shape(quantile)
1884+
)
1885+
self.loc, self.scale, self.quantile = promote_shapes(
1886+
loc, scale, quantile, shape=batch_shape
1887+
)
1888+
super(AsymmetricLaplaceQuantile, self).__init__(
1889+
batch_shape=batch_shape, validate_args=validate_args
1890+
)
1891+
asymmetry = (1 / ((1 / quantile) - 1)) ** 0.5
1892+
scale_classic = scale * asymmetry / quantile
1893+
self._ald = AsymmetricLaplace(loc=loc, scale=scale_classic, asymmetry=asymmetry)
1894+
1895+
def log_prob(self, value):
1896+
if self._validate_args:
1897+
self._validate_sample(value)
1898+
return self._ald.log_prob(value)
1899+
1900+
def sample(self, key, sample_shape=()):
1901+
return self._ald.sample(key, sample_shape=sample_shape)
1902+
1903+
@property
1904+
def mean(self):
1905+
return self._ald.mean
1906+
1907+
@property
1908+
def variance(self):
1909+
return self._ald.variance
1910+
1911+
def cdf(self, value):
1912+
return self._ald.cdf(value)
1913+
1914+
def icdf(self, value):
1915+
return self._ald.icdf(value)

test/test_distributions.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,9 @@ def tree_unflatten(cls, aux_data, params):
196196

197197

198198
_DIST_MAP = {
199+
dist.AsymmetricLaplace: lambda loc, scale, asymmetry: osp.laplace_asymmetric(
200+
asymmetry, loc=loc, scale=scale
201+
),
199202
dist.BernoulliProbs: lambda probs: osp.bernoulli(p=probs),
200203
dist.BernoulliLogits: lambda logits: osp.bernoulli(p=_to_probs_bernoulli(logits)),
201204
dist.Beta: lambda con1, con0: osp.beta(con1, con0),
@@ -253,6 +256,17 @@ def get_sp_dist(jax_dist):
253256

254257

255258
CONTINUOUS = [
259+
T(dist.AsymmetricLaplace, 1.0, 0.5, 1.0),
260+
T(dist.AsymmetricLaplace, np.array([1.0, 2.0]), 2.0, 2.0),
261+
T(dist.AsymmetricLaplace, np.array([[1.0], [2.0]]), 2.0, np.array([3.0, 5.0])),
262+
T(dist.AsymmetricLaplaceQuantile, 0.0, 1.0, 0.5),
263+
T(dist.AsymmetricLaplaceQuantile, np.array([1.0, 2.0]), 2.0, 0.7),
264+
T(
265+
dist.AsymmetricLaplaceQuantile,
266+
np.array([[1.0], [2.0]]),
267+
2.0,
268+
np.array([0.2, 0.8]),
269+
),
256270
T(dist.Beta, 0.2, 1.1),
257271
T(dist.Beta, 1.0, np.array([2.0, 2.0])),
258272
T(dist.Beta, 1.0, np.array([[1.0, 1.0], [2.0, 2.0]])),

0 commit comments

Comments
 (0)