Bayesian estimation of prevalence models¶
This section describes how to use the prevalence models for Bayesian estimation in Episuite.
See also
- Bayesian modelling for COVID-19 seroprevalence studies
This is a talk that uses the same models implemented in Episuite.
- Estimating SARS-CoV-2 seroprevalence and epidemiological parameters with uncertainty from serological surveys
Excellent recent articule by Larremore et al. [LFB+20] on estimation for seroprevalence studies.
Episuite models are based on Numpyro, with uses Jax.
[1]:
import numpyro
import arviz as az
from numpyro.infer import MCMC, NUTS
from numpyro.infer import init_to_value, init_to_feasible
from matplotlib import pyplot as plt
from jax import random
from episuite import prevalence
# Set 2 cores in Numpyro
numpyro.set_host_device_count(2)
True prevalence model¶
In this section we will estimate a true prevalence model, a model that assumes that you’re observing true prevalences (i.e. on a seroprevalence study w/ perfect testing validation properties). Leter we will improve on it by assuming imperfect testing.
[2]:
num_warmup, num_samples = 500, 2000
[3]:
# Random generator needed by jax
rng_key = random.PRNGKey(42)
rng_key, rng_key_ = random.split(rng_key)
[4]:
# Scenario: collected 4000 samples and 20 were found positive
total_observations = 4000
positive_observations = 20
[5]:
# Configure MCMC with the true_prevalence_model from Episuite
kernel = NUTS(prevalence.true_prevalence_model, init_strategy=init_to_feasible())
mcmc = MCMC(kernel, num_warmup, num_samples, num_chains=1)
[6]:
# Run MCMC
mcmc.run(rng_key_,
obs_positive=positive_observations,
obs_total=total_observations)
sample: 100%|██████████| 2500/2500 [00:06<00:00, 407.97it/s, 1 steps of size 1.12e+00. acc. prob=0.91]
[7]:
samples = mcmc.get_samples()
mcmc.print_summary()
mean std median 5.0% 95.0% n_eff r_hat
true_p 0.01 0.00 0.01 0.00 0.01 903.02 1.00
Number of divergences: 0
[8]:
inference_data = az.from_numpyro(mcmc)
[9]:
az.plot_forest(inference_data)
plt.show()
[10]:
az.plot_trace(inference_data)
plt.show()
[11]:
az.plot_posterior(inference_data, round_to=3, point_estimate="mode")
plt.show()
Apparent prevalence model¶
In this section we will estimate an apparent prevalence model, a model that incorporates the sensitiviy and specificity properties of the test validation results. We will use here a scenario where we collected samples and tested for SARS-CoV-2 and assume properties from a real test from the brand Wondfo (used in Brazil on different seroprevalence surveys).
[12]:
# Wondfo test parameters (taken from their product description from tests they made with a PCR gold standard)
#
# From a total of 42 confirmed COVID-19 positive patients: the test detected 42 positive and 0 negative.
# From a total of 172 COVID 19 negative patients: the test detected 2 positive and 170 negative.
# Specificity parameters
n_sp = 172
x_sp = 170
# Sensitivity paramters
n_se = 42
x_se = 42
# These are results from a seroprevalence study in Brazil
observed_total = 4189
observed_positive = 2
[13]:
kernel = NUTS(prevalence.apparent_prevalence_model,
init_strategy=init_to_feasible())
mcmc = MCMC(kernel, num_warmup, num_samples, num_chains=1)
[14]:
mcmc.run(rng_key_,
x_se=x_se, n_se=n_se, # Sensitivity parameters of the test used
x_sp=x_sp, n_sp=n_sp, # Specificity parameters of the test used
obs_positive=observed_positive, # Positive results
obs_total=observed_total) # Total samples
sample: 100%|██████████| 2500/2500 [00:07<00:00, 352.70it/s, 7 steps of size 5.02e-01. acc. prob=0.91]
[15]:
mcmc.print_summary(exclude_deterministic=False)
samples_1 = mcmc.get_samples()
mean std median 5.0% 95.0% n_eff r_hat
apparent_p 0.00 0.00 0.00 0.00 0.00 1480.38 1.00
se_p 0.98 0.02 0.98 0.95 1.00 1687.51 1.00
sp_p 1.00 0.00 1.00 1.00 1.00 1247.72 1.00
true_p 0.00 0.00 0.00 0.00 0.00 1587.27 1.00
Number of divergences: 0
[16]:
inference_data = az.from_numpyro(mcmc)
[17]:
az.plot_forest(inference_data)
plt.show()
[18]:
az.plot_trace(inference_data)
plt.show()
[19]:
az.plot_posterior(inference_data, round_to=3,
point_estimate="mode", var_names=["true_p", "apparent_p"])
plt.show()
[20]:
az.plot_pair(inference_data, var_names=["se_p", "sp_p"], kind="kde",
colorbar=True, figsize=(10, 8), kde_kwargs={"fill_last": True})
plt.show()
Note
Please note that in this example we used only 1 MCMC chain and a few samples, on a real scenario you are advised to use multiple chains to have better diagnostics and much more samples.