
Commit 7d50393

gh-1806: Implementation of Doubly Truncated Power Law and Lower Truncated Power Law (#1807)
* implementation of DoublyTruncatedPowerLaw
* implementation of LowerTruncatedPowerLaw
* chore: mathematical description in docstrings
* chore: mathematical details of `LowerTruncatedPowerLaw`
* chore: Fix bug in DoublyTruncatedPowerLaw cdf and icdf calculation
* chore: Refactor mean and variance calculation by using kth-moment in DoublyTruncatedPowerLaw
* chore: Refactor mean and variance calculation in LowerTruncatedPowerLaw
* chore: masking in icdf of LowerTruncatedPowerLaw
* chore: entropy of LowerTruncatedPowerLaw
* chore: `lax.square` replaced with `jnp.square`
* chore: moments and entropy were extra and removed
* chore: unit tests
* fix: nan gradients fixed, values still diverging
* Updated UpperTruncatedPowerLaw with adequate derivations, including fixing the discontinuity point for alpha equals minus one and correct tangents for lower and upper bounds.
* Changed constraints of alpha of LowerTruncatedPowerLaw to be smaller than minus one, changed the equation to keep equational uniformity with UpperTruncatedPowerLaw, and improved stability of the calculation by integrating formula parts inside the log to prevent NaN values due to negative values of some parts that could balance out in the end.
* chore: code and docstring formatted
* chore: equation refactor and simplified
* chore: equation refactor and simplified
* chore: use numpy arrays and numpy constants
* chore: high precision computation enable for powerlaws
* chore: `__name__` attribute calls removed
* chore: powerlaws shifted with truncated distributions
* chore: spelling mistakes fixed with code spell checker pre-commit hook
* fix typo: perforance->perforamce->performance
* chore: explicit enabling/disabling of 64bit floating point numbers
* chore: disable everytime and enable x64 for power laws
* chore: disable x64 for every test
* chore: linked explanation in comments for disabling x64 for future reference for devs
* chore: high precision test handled efficiently for DoublyTruncatedPowerLaw
* chore: high precision exception handled in test_log_prob_gradient

---------

Co-authored-by: David Ziegler <[email protected]>
1 parent ca8fb39 commit 7d50393
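
The commit message mentions a discontinuity at alpha equals minus one. A minimal sketch of why that case needs its own branch, assuming the density is proportional to `x**alpha` on `[low, high]` with `0 < low < high`. The function names `log_norm`, `log_prob`, and `cdf` are illustrative only and are not NumPyro's implementation, which is written in JAX:

```python
import math


def log_norm(alpha, low, high):
    """Log of the integral of x**alpha over [low, high]."""
    if alpha == -1.0:
        # The general formula divides by (1 + alpha); its limit is log(high / low).
        return math.log(math.log(high) - math.log(low))
    beta = 1.0 + alpha
    # Numerator and beta share the same sign, so the ratio is positive.
    return math.log((high**beta - low**beta) / beta)


def log_prob(x, alpha, low, high):
    """Log-density of x under p(x) proportional to x**alpha on [low, high]."""
    if not (low <= x <= high):
        return -math.inf
    return alpha * math.log(x) - log_norm(alpha, low, high)


def cdf(x, alpha, low, high):
    """Cumulative distribution function for low <= x <= high."""
    if alpha == -1.0:
        return (math.log(x) - math.log(low)) / (math.log(high) - math.log(low))
    beta = 1.0 + alpha
    return (x**beta - low**beta) / (high**beta - low**beta)
```

The explicit branch keeps the normalizer finite and continuous as alpha approaches minus one from either side.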

24 files changed, +672 -85 lines changed

.github/workflows/ci.yml

Lines changed: 3 additions & 0 deletions
@@ -72,6 +72,9 @@ jobs:
       - name: Test with pytest
         run: |
           CI=1 pytest -vs -k "not test_example" --durations=100 --ignore=test/infer/ --ignore=test/contrib/
+      - name: Test x64
+        run: |
+          JAX_ENABLE_X64=1 pytest -vs test/test_distributions.py -k powerLaw


   test-inference:
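
The added step reruns the power-law tests with `JAX_ENABLE_X64=1`, which lets JAX use 64-bit floats instead of its 32-bit default. A standard-library sketch of the precision gap between the two widths, emulating 32-bit rounding with `struct`. The helper `to_f32` is hypothetical and is not part of JAX or NumPyro:

```python
import struct


def to_f32(x):
    """Round a Python float (64-bit) to the nearest 32-bit float."""
    return struct.unpack("f", struct.pack("f", x))[0]


tiny = 1e-8
sum64 = 1.0 + tiny          # the increment survives in 64-bit arithmetic
sum32 = to_f32(1.0 + tiny)  # the increment is rounded away in 32-bit

print(sum64 - 1.0 > 0.0)    # True
print(sum32 - 1.0 == 0.0)   # True
```

Differences of nearly equal quantities, such as the terms that appear in a truncated power-law normalizer, are one place where this kind of rounding can matter.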

.pre-commit-config.yaml

Lines changed: 8 additions & 0 deletions
@@ -21,3 +21,11 @@ repos:
       - id: check-yaml
       - id: check-added-large-files
         exclude: notebooks/*
+
+  - repo: https://github.com/codespell-project/codespell
+    rev: v2.3.0
+    hooks:
+      - id: codespell
+        stages: [commit, commit-msg]
+        args:
+          [--ignore-words-list, "Teh,aas", --check-filenames, --skip, "*.ipynb"]

README.md

Lines changed: 39 additions & 36 deletions
Large diffs are not rendered by default.

docker/README.md

Lines changed: 1 addition & 1 deletion
@@ -34,5 +34,5 @@ Design Choices:
 Future Work:

 - Right now the jax, jaxlib, and numpyro versions are manually specified, so they have to be updated every NumPyro release. There are two ways forward for this:
-1. If there is a CI/CD in place to build and push images to a repository like Dockerhub, then the jax, jaxlib, and numpyro versions can be passed in as environment variables (for example, if something like [Drone CI](http://plugins.drone.io/drone-plugins/drone-docker/) is used). If implemented this way, the jax/jaxlib/numpyro versions will be ephemereal (not stored in source code).
+1. If there is a CI/CD in place to build and push images to a repository like Dockerhub, then the jax, jaxlib, and numpyro versions can be passed in as environment variables (for example, if something like [Drone CI](http://plugins.drone.io/drone-plugins/drone-docker/) is used). If implemented this way, the jax/jaxlib/numpyro versions will be ephemeral (not stored in source code).
 2. Alternative, one can create a Python script that will modify the Dockerfiles upon release accordingly (using a hook of some sort).

docs/source/distributions.rst

Lines changed: 16 additions & 0 deletions
@@ -662,6 +662,14 @@ VonMises
 Truncated Distributions
 -----------------------

+DoublyTruncatedPowerLaw
+^^^^^^^^^^^^^^^^^^^^^^^
+.. autoclass:: numpyro.distributions.truncated.DoublyTruncatedPowerLaw
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :member-order: bysource
+
 LeftTruncatedDistribution
 ^^^^^^^^^^^^^^^^^^^^^^^^^
 .. autoclass:: numpyro.distributions.truncated.LeftTruncatedDistribution
@@ -670,6 +678,14 @@ LeftTruncatedDistribution
     :show-inheritance:
     :member-order: bysource

+LowerTruncatedPowerLaw
+^^^^^^^^^^^^^^^^^^^^^^
+.. autoclass:: numpyro.distributions.truncated.LowerTruncatedPowerLaw
+    :members:
+    :undoc-members:
+    :show-inheritance:
+    :member-order: bysource
+
 RightTruncatedDistribution
 ^^^^^^^^^^^^^^^^^^^^^^^^^^
 .. autoclass:: numpyro.distributions.truncated.RightTruncatedDistribution

examples/annotation.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@

 All models have discrete latent variables. Under the hood, we enumerate over
 (marginalize out) those discrete latent sites in inference. Those models have different
-complexity so they are great refererences for those who are new to Pyro/NumPyro
+complexity so they are great references for those who are new to Pyro/NumPyro
 enumeration mechanism. We recommend readers compare the implementations with the
 corresponding plate diagrams in [1] to see how concise a Pyro/NumPyro program is.

examples/ar2.py

Lines changed: 1 addition & 1 deletion
@@ -103,7 +103,7 @@ def run_inference(model, args, rng_key, y):


 def main(args):
-    # generate artifical dataset
+    # generate artificial dataset
     num_data = args.num_data
     rng_key = jax.random.PRNGKey(0)
     t = jnp.arange(0, num_data)

examples/holt_winters.py

Lines changed: 1 addition & 1 deletion
@@ -147,7 +147,7 @@ def predict(model, args, samples, rng_key, y, n_seasons):


 def main(args):
-    # generate artifical dataset
+    # generate artificial dataset
     rng_key, _ = random.split(random.PRNGKey(0))
     T = args.T
     t = jnp.linspace(0, T + args.future, (T + args.future) * N_POINTS_PER_UNIT)

examples/mortality.py

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@
 dimensions of the age, space and time variables. This allows us to efficiently broadcast arrays
 in the likelihood.

-As written above, the model includes a lot of centred random effects. The NUTS alogrithm benefits
+As written above, the model includes a lot of centred random effects. The NUTS algorithm benefits
 from a non-centred reparamatrisation to overcome difficult posterior geometries [2]. Rather than
 manually writing out the non-centred parametrisation, we make use of the NumPyro's automatic
 reparametrisation in :class:`~numpyro.infer.reparam.LocScaleReparam`.

numpyro/distributions/__init__.py

Lines changed: 15 additions & 11 deletions
@@ -95,7 +95,9 @@
 from numpyro.distributions.mixtures import Mixture, MixtureGeneral, MixtureSameFamily
 from numpyro.distributions.transforms import biject_to
 from numpyro.distributions.truncated import (
+    DoublyTruncatedPowerLaw,
     LeftTruncatedDistribution,
+    LowerTruncatedPowerLaw,
     RightTruncatedDistribution,
     TruncatedCauchy,
     TruncatedDistribution,
@@ -122,6 +124,7 @@
     "Binomial",
     "BinomialLogits",
     "BinomialProbs",
+    "CAR",
     "Categorical",
     "CategoricalLogits",
     "CategoricalProbs",
@@ -132,9 +135,10 @@
     "DirichletMultinomial",
     "DiscreteUniform",
     "Distribution",
+    "DoublyTruncatedPowerLaw",
     "EulerMaruyama",
-    "Exponential",
     "ExpandedDistribution",
+    "Exponential",
     "FoldedDistribution",
     "Gamma",
     "GammaPoisson",
@@ -152,29 +156,29 @@
     "Independent",
     "InverseGamma",
     "Kumaraswamy",
-    "LKJ",
-    "LKJCholesky",
     "Laplace",
     "LeftTruncatedDistribution",
+    "LKJ",
+    "LKJCholesky",
     "Logistic",
     "LogNormal",
     "LogUniform",
-    "MatrixNormal",
+    "LowerTruncatedPowerLaw",
+    "LowRankMultivariateNormal",
     "MaskedDistribution",
+    "MatrixNormal",
     "Mixture",
-    "MixtureSameFamily",
     "MixtureGeneral",
+    "MixtureSameFamily",
     "Multinomial",
     "MultinomialLogits",
     "MultinomialProbs",
     "MultivariateNormal",
-    "CAR",
     "MultivariateStudentT",
-    "LowRankMultivariateNormal",
-    "Normal",
-    "NegativeBinomialProbs",
-    "NegativeBinomialLogits",
     "NegativeBinomial2",
+    "NegativeBinomialLogits",
+    "NegativeBinomialProbs",
+    "Normal",
     "OrderedLogistic",
     "Pareto",
     "Poisson",
@@ -199,7 +203,7 @@
     "Wishart",
     "WishartCholesky",
     "ZeroInflatedDistribution",
-    "ZeroInflatedPoisson",
     "ZeroInflatedNegativeBinomial2",
+    "ZeroInflatedPoisson",
     "ZeroSumNormal",
 ]
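
The reordering in this file appears consistent with case-insensitive alphabetical order, for example `"LKJ"` moving after `"LeftTruncatedDistribution"`. The commit does not state the rule, so treat it as an assumption. A small sketch of a check for that ordering, where `is_sorted_case_insensitive` is a hypothetical helper and not part of NumPyro:

```python
def is_sorted_case_insensitive(names):
    """Return True if names are in ascending order when case is ignored."""
    lowered = [name.lower() for name in names]
    return lowered == sorted(lowered)


# A slice of the list after the change, copied from the hunk.
new_order = [
    "Kumaraswamy",
    "Laplace",
    "LeftTruncatedDistribution",
    "LKJ",
    "LKJCholesky",
    "Logistic",
    "LogNormal",
    "LogUniform",
    "LowerTruncatedPowerLaw",
    "LowRankMultivariateNormal",
    "MaskedDistribution",
    "MatrixNormal",
]

# The same region before the change.
old_order = ["Kumaraswamy", "LKJ", "LKJCholesky", "Laplace", "LeftTruncatedDistribution"]

print(is_sorted_case_insensitive(new_order))  # True
print(is_sorted_case_insensitive(old_order))  # False
```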
