|
25 | 25 | # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
26 | 26 | # POSSIBILITY OF SUCH DAMAGE.
|
27 | 27 |
|
28 |
| - |
29 | 28 | from jax import lax
|
30 | 29 | import jax.nn as nn
|
31 | 30 | import jax.numpy as jnp
|
|
67 | 66 | )
|
68 | 67 |
|
69 | 68 |
|
| 69 | +class AsymmetricLaplace(Distribution): |
| 70 | + arg_constraints = { |
| 71 | + "loc": constraints.real, |
| 72 | + "scale": constraints.positive, |
| 73 | + "asymmetry": constraints.positive, |
| 74 | + } |
| 75 | + reparametrized_params = ["loc", "scale", "asymmetry"] |
| 76 | + support = constraints.real |
| 77 | + |
| 78 | + def __init__(self, loc=0.0, scale=1.0, asymmetry=1.0, validate_args=None): |
| 79 | + batch_shape = lax.broadcast_shapes( |
| 80 | + jnp.shape(loc), jnp.shape(scale), jnp.shape(asymmetry) |
| 81 | + ) |
| 82 | + self.loc, self.scale, self.asymmetry = promote_shapes( |
| 83 | + loc, scale, asymmetry, shape=batch_shape |
| 84 | + ) |
| 85 | + super(AsymmetricLaplace, self).__init__( |
| 86 | + batch_shape=batch_shape, validate_args=validate_args |
| 87 | + ) |
| 88 | + |
| 89 | + @lazy_property |
| 90 | + def left_scale(self): |
| 91 | + return self.scale * self.asymmetry |
| 92 | + |
| 93 | + @lazy_property |
| 94 | + def right_scale(self): |
| 95 | + return self.scale / self.asymmetry |
| 96 | + |
| 97 | + def log_prob(self, value): |
| 98 | + if self._validate_args: |
| 99 | + self._validate_sample(value) |
| 100 | + z = value - self.loc |
| 101 | + z = -jnp.abs(z) / jnp.where(z < 0, self.left_scale, self.right_scale) |
| 102 | + return z - jnp.log(self.left_scale + self.right_scale) |
| 103 | + |
| 104 | + def sample(self, key, sample_shape=()): |
| 105 | + assert is_prng_key(key) |
| 106 | + shape = (2,) + sample_shape + self.batch_shape + self.event_shape |
| 107 | + u, v = random.exponential(key, shape=shape) |
| 108 | + return self.loc - self.left_scale * u + self.right_scale * v |
| 109 | + |
| 110 | + @property |
| 111 | + def mean(self): |
| 112 | + total_scale = self.left_scale + self.right_scale |
| 113 | + mean = self.loc + (self.right_scale**2 - self.left_scale**2) / total_scale |
| 114 | + return jnp.broadcast_to(mean, self.batch_shape) |
| 115 | + |
| 116 | + @property |
| 117 | + def variance(self): |
| 118 | + left = self.left_scale |
| 119 | + right = self.right_scale |
| 120 | + total = left + right |
| 121 | + p = left / total |
| 122 | + q = right / total |
| 123 | + variance = p * left**2 + q * right**2 + p * q * total**2 |
| 124 | + return jnp.broadcast_to(variance, self.batch_shape) |
| 125 | + |
| 126 | + def cdf(self, value): |
| 127 | + z = value - self.loc |
| 128 | + k = self.asymmetry |
| 129 | + return jnp.where( |
| 130 | + z >= 0, |
| 131 | + 1 - (1 / (1 + k**2)) * jnp.exp(-jnp.abs(z) / self.right_scale), |
| 132 | + k**2 / (1 + k**2) * jnp.exp(-jnp.abs(z) / self.left_scale), |
| 133 | + ) |
| 134 | + |
| 135 | + def icdf(self, value): |
| 136 | + k = self.asymmetry |
| 137 | + temp = k**2 / (1 + k**2) |
| 138 | + return jnp.where( |
| 139 | + value <= temp, |
| 140 | + self.loc + self.left_scale * jnp.log(value / temp), |
| 141 | + self.loc - self.right_scale * jnp.log((1 + k**2) * (1 - value)), |
| 142 | + ) |
| 143 | + |
| 144 | + |
70 | 145 | class Beta(Distribution):
|
71 | 146 | arg_constraints = {
|
72 | 147 | "concentration1": constraints.positive,
|
@@ -1777,3 +1852,64 @@ def __init__(self, mean, concentration, validate_args=None):
|
1777 | 1852 | (1.0 - mean) * concentration,
|
1778 | 1853 | validate_args=validate_args,
|
1779 | 1854 | )
|
| 1855 | + |
| 1856 | + |
| 1857 | +class AsymmetricLaplaceQuantile(Distribution): |
| 1858 | + """An alternative parameterization of AsymmetricLaplace commonly applied in |
| 1859 | + Bayesian quantile regression. |
| 1860 | +
|
| 1861 | + Instead of the `asymmetry` parameter employed by AsymmetricLaplace, to |
| 1862 | + define the balance between left- versus right-hand sides of the |
| 1863 | + distribution, this class utilizes a `quantile` parameter, which describes |
| 1864 | + the proportion of probability density that falls to the left-hand side of |
| 1865 | + the distribution. |
| 1866 | +
|
| 1867 | + The `scale` parameter is also interpreted slightly differently than in |
| 1868 | + AsymmetricLaplce. When `loc=0` and `scale=1`, AsymmetricLaplace(0,1,1) |
| 1869 | + is equivalent to Laplace(0,1), while AsymmetricLaplaceQuantile(0,1,0.5) is |
| 1870 | + equivalent to Laplace(0,2). |
| 1871 | + """ |
| 1872 | + |
| 1873 | + arg_constraints = { |
| 1874 | + "loc": constraints.real, |
| 1875 | + "scale": constraints.positive, |
| 1876 | + "quantile": constraints.open_interval(0.0, 1.0), |
| 1877 | + } |
| 1878 | + reparametrized_params = ["loc", "scale", "quantile"] |
| 1879 | + support = constraints.real |
| 1880 | + |
| 1881 | + def __init__(self, loc=0.0, scale=1.0, quantile=0.5, validate_args=None): |
| 1882 | + batch_shape = lax.broadcast_shapes( |
| 1883 | + jnp.shape(loc), jnp.shape(scale), jnp.shape(quantile) |
| 1884 | + ) |
| 1885 | + self.loc, self.scale, self.quantile = promote_shapes( |
| 1886 | + loc, scale, quantile, shape=batch_shape |
| 1887 | + ) |
| 1888 | + super(AsymmetricLaplaceQuantile, self).__init__( |
| 1889 | + batch_shape=batch_shape, validate_args=validate_args |
| 1890 | + ) |
| 1891 | + asymmetry = (1 / ((1 / quantile) - 1)) ** 0.5 |
| 1892 | + scale_classic = scale * asymmetry / quantile |
| 1893 | + self._ald = AsymmetricLaplace(loc=loc, scale=scale_classic, asymmetry=asymmetry) |
| 1894 | + |
| 1895 | + def log_prob(self, value): |
| 1896 | + if self._validate_args: |
| 1897 | + self._validate_sample(value) |
| 1898 | + return self._ald.log_prob(value) |
| 1899 | + |
| 1900 | + def sample(self, key, sample_shape=()): |
| 1901 | + return self._ald.sample(key, sample_shape=sample_shape) |
| 1902 | + |
| 1903 | + @property |
| 1904 | + def mean(self): |
| 1905 | + return self._ald.mean |
| 1906 | + |
| 1907 | + @property |
| 1908 | + def variance(self): |
| 1909 | + return self._ald.variance |
| 1910 | + |
| 1911 | + def cdf(self, value): |
| 1912 | + return self._ald.cdf(value) |
| 1913 | + |
| 1914 | + def icdf(self, value): |
| 1915 | + return self._ald.icdf(value) |
0 commit comments