Skip to content

Commit a4222e9

Browse files
committed
remove TODOs & unused commented code; create model and src files
1 parent d1c1f3d commit a4222e9

File tree

17 files changed

+991
-116
lines changed

17 files changed

+991
-116
lines changed

README.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@
22

33
⚠️ This is a work in progress.
44

5-
_`pyrenew-flu-light` is an instantiation of an [Epidemia](https://imperialcollegelondon.github.io/epidemia/) influenza forecasting model in [PyRenew](https://github.com/CDCgov/PyRenew)_
6-
7-
5+
_`pyrenew-flu-light` is an instantiation of an [Epidemia](https://imperialcollegelondon.github.io/epidemia/) influenza forecasting model in [PyRenew](https://github.com/CDCgov/PyRenew)._
86

97
NOTE: Presently, this `pyrenew-flu-light` cannot be installed and used with current NHSN, as its author is validating it on historical influenza data, which is .
108

assets/paste_bin.txt

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,45 @@
22
NOTES
33

44

5+
6+
REMOVE plot and comparison functions for now
7+
8+
# ax.set_title("Posterior Predictive Plot")
9+
# ax.set_ylabel("Hospital Admissions")
10+
# ax.set_xlabel("Days")
11+
# plt.show()
12+
13+
# prior_p_ss_figures_and_descriptions = plot_sample_variables(
14+
# samples=prior_p_ss,
15+
# variables=["Rts", "latent_infections", "negbinom_rv"],
16+
# observations=obs,
17+
# ylabels=[
18+
# "Basic Reproduction Number",
19+
# "Latent Infections",
20+
# "Hospital Admissions",
21+
# ],
22+
# plot_types=["TRACE", "PPC", "HDI"],
23+
# plot_kwargs={
24+
# "HDI": {"hdi_prob": 0.95, "plot_kwargs": {"ls": "-."}},
25+
# "TRACE": {"var_names": ["Rts", "latent_infections"]},
26+
# "PPC": {"alpha": 0.05, "textsize": 12},
27+
# },
28+
# )
29+
30+
# print(prior_p_ss_figures_and_descriptions)
31+
32+
# if args.forecasting:
33+
34+
# prior_p_ss & post_p_ss get their own pdf (markdown first then subprocess)
35+
# each variable is plotted out, if possible
36+
# arviz diagnostics
37+
38+
39+
40+
41+
42+
43+
544
seeding ("initialization" in MSR lingo):
645

746
no renewal process, no need for a defined R(t)

src/model/__init__.py

Whitespace-only changes.

src/model/inf.py

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
import logging
2+
3+
import jax.numpy as jnp
4+
import numpy as np
5+
import numpyro
6+
import numpyro.distributions as dist
7+
from jax.typing import ArrayLike
8+
from pyrenew.latent import logistic_susceptibility_adjustment
9+
from pyrenew.metaclass import RandomVariable
10+
11+
12+
class CFAEPIM_Infections(RandomVariable):
13+
"""
14+
Class representing the infection process in
15+
the CFAEPIM model. This class handles the sampling of
16+
infection counts over time, considering the
17+
reproduction number, generation interval, and population size,
18+
while accounting for susceptibility depletion.
19+
20+
Parameters
21+
----------
22+
I0 : ArrayLike
23+
Initial infection counts.
24+
susceptibility_prior : numpyro.distributions
25+
Prior distribution for the susceptibility proportion
26+
(S_{v-1} / P).
27+
"""
28+
29+
def __init__(
30+
self,
31+
I0: ArrayLike,
32+
susceptibility_prior: numpyro.distributions,
33+
): # numpydoc ignore=GL08
34+
logging.info("Initializing CFAEPIM_Infections")
35+
36+
self.I0 = I0
37+
self.susceptibility_prior = susceptibility_prior
38+
39+
@staticmethod
40+
def validate(I0: any, susceptibility_prior: any) -> None:
41+
"""
42+
Validate the parameters of the
43+
infection process. Checks that the initial infections
44+
(I0) and susceptibility_prior are
45+
correctly specified. If any parameter is invalid,
46+
an appropriate error is raised.
47+
48+
Raises
49+
------
50+
TypeError
51+
If I0 is not array-like or
52+
susceptibility_prior is not
53+
a numpyro distribution.
54+
"""
55+
logging.info("Validating CFAEPIM_Infections parameters")
56+
if not isinstance(I0, (np.ndarray, jnp.ndarray)):
57+
raise TypeError(
58+
f"Initial infections (I0) must be an array-like structure; was type {type(I0)}"
59+
)
60+
61+
if not isinstance(susceptibility_prior, dist.Distribution):
62+
raise TypeError(
63+
f"susceptibility_prior must be a numpyro distribution; was type {type(susceptibility_prior)}"
64+
)
65+
66+
def sample(
67+
self, Rt: ArrayLike, gen_int: ArrayLike, P: float, **kwargs
68+
) -> tuple:
69+
"""
70+
Given an array of reproduction numbers,
71+
a generation interval, and the size of a
72+
jurisdiction's population,
73+
calculate infections under the scheme
74+
of susceptible depletion.
75+
76+
Parameters
77+
----------
78+
Rt : ArrayLike
79+
Reproduction numbers over time; this is an array of
80+
Rt values for each time step.
81+
gen_int : ArrayLike
82+
Generation interval probability mass function. This is
83+
an array of probabilities representing the
84+
distribution of times between successive infections
85+
in a chain of transmission.
86+
P : float
87+
Population size. This is the total population
88+
size used for susceptibility adjustment.
89+
**kwargs : dict, optional
90+
Additional keyword arguments passed through to internal
91+
sample calls, should there be any.
92+
93+
Returns
94+
-------
95+
tuple
96+
A tuple containing two arrays: all_I_t, an array of
97+
latent infections at each time step and all_S_t, an
98+
array of susceptible individuals at each time step.
99+
100+
Raises
101+
------
102+
ValueError
103+
If the length of the initial infections
104+
vector (I0) is less than the length of
105+
the generation interval.
106+
"""
107+
108+
# get initial infections
109+
I0_samples = self.I0.sample()
110+
I0 = I0_samples[0].value
111+
112+
logging.debug(f"I0 samples: {I0}")
113+
114+
# reverse generation interval (recency)
115+
gen_int_rev = jnp.flip(gen_int)
116+
117+
if I0.size < gen_int.size:
118+
raise ValueError(
119+
"Initial infections vector must be at least as long as "
120+
"the generation interval. "
121+
f"Initial infections vector length: {I0.size}, "
122+
f"generation interval length: {gen_int.size}."
123+
)
124+
recent_I0 = I0[-gen_int_rev.size :]
125+
126+
# sample the initial susceptible population proportion S_{v-1} / P from prior
127+
init_S_proportion = numpyro.sample(
128+
"S_v_minus_1_over_P", self.susceptibility_prior
129+
)
130+
logging.debug(f"Initial susceptible proportion: {init_S_proportion}")
131+
132+
# calculate initial susceptible population S_{v-1}
133+
init_S = init_S_proportion * P
134+
135+
def update_infections(carry, Rt): # numpydoc ignore=GL08
136+
S_t, I_recent = carry
137+
138+
# compute raw infections
139+
i_raw_t = Rt * jnp.dot(I_recent, gen_int_rev)
140+
141+
# apply the logistic susceptibility adjustment to a potential new incidence
142+
i_t = logistic_susceptibility_adjustment(
143+
I_raw_t=i_raw_t, frac_susceptible=S_t / P, n_population=P
144+
)
145+
146+
# update susceptible population
147+
S_t -= i_t
148+
149+
# update infections
150+
I_recent = jnp.concatenate([I_recent[:-1], jnp.array([i_t])])
151+
152+
return (S_t, I_recent), i_t
153+
154+
# initial carry state
155+
init_carry = (init_S, recent_I0)
156+
157+
# scan to iterate over time steps and update infections
158+
(all_S_t, _), all_I_t = numpyro.contrib.control_flow.scan(
159+
update_infections, init_carry, Rt
160+
)
161+
162+
logging.debug(f"All infections: {all_I_t}")
163+
logging.debug(f"All susceptibles: {all_S_t}")
164+
165+
return all_I_t, all_S_t

0 commit comments

Comments
 (0)