Bayesian estimation of prevalence models

This section describes how to use the prevalence models for Bayesian estimation in Episuite.

See also

Bayesian modelling for COVID-19 seroprevalence studies

This is a talk that uses the same models implemented in Episuite.

Estimating SARS-CoV-2 seroprevalence and epidemiological parameters with uncertainty from serological surveys

An excellent recent article by Larremore et al. [LFB+20] on estimation for seroprevalence studies.

Episuite models are built on NumPyro, which uses JAX.

[1]:
import numpyro
import arviz as az
from numpyro.infer import MCMC, NUTS
from numpyro.infer import init_to_value, init_to_feasible
from matplotlib import pyplot as plt
from jax import random

from episuite import prevalence

# Expose 2 host (CPU) devices so NumPyro can run MCMC chains in parallel
numpyro.set_host_device_count(2)

True prevalence model

In this section we will estimate a true prevalence model: a model that assumes the observed proportion of positives reflects the true prevalence (e.g. a seroprevalence study that uses a test with perfect sensitivity and specificity). Later we will improve on it by accounting for imperfect testing.

[2]:
num_warmup, num_samples = 500, 2000
[3]:
# Random generator needed by jax
rng_key = random.PRNGKey(42)
rng_key, rng_key_ = random.split(rng_key)
[4]:
# Scenario: collected 4000 samples and 20 were found positive
total_observations = 4000
positive_observations = 20
[5]:
# Configure MCMC with the true_prevalence_model from Episuite
kernel = NUTS(prevalence.true_prevalence_model, init_strategy=init_to_feasible())
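# Recent NumPyro versions require num_warmup and num_samples as keyword arguments
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=1)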
[6]:
# Run MCMC
mcmc.run(rng_key_,
         obs_positive=positive_observations,
         obs_total=total_observations)
sample: 100%|██████████| 2500/2500 [00:06<00:00, 407.97it/s, 1 steps of size 1.12e+00. acc. prob=0.91]
[7]:
samples = mcmc.get_samples()
mcmc.print_summary()

                mean       std    median      5.0%     95.0%     n_eff     r_hat
    true_p      0.01      0.00      0.01      0.00      0.01    903.02      1.00

Number of divergences: 0
[8]:
inference_data = az.from_numpyro(mcmc)
[9]:
az.plot_forest(inference_data)
plt.show()
[Figure: forest plot of the true_p posterior]
[10]:
az.plot_trace(inference_data)
plt.show()
[Figure: trace plot of true_p]
[11]:
az.plot_posterior(inference_data, round_to=3, point_estimate="mode")
plt.show()
[Figure: posterior distribution of true_p with its mode]
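Because the true prevalence model has a single parameter, its posterior can be cross-checked by hand. Assuming a flat Beta(1, 1) prior (an assumption; Episuite's prior may differ), Beta-Binomial conjugacy gives a Beta(1 + 20, 1 + 4000 - 20) = Beta(21, 3981) posterior. Its mean (about 0.0052) and standard deviation (about 0.0011) are consistent with the rounded values in the summary above, and its mode equals the maximum-likelihood estimate 20/4000 = 0.005, which the mode reported in the posterior plot should be close to.

```python
import math

# Assumption: a flat Beta(1, 1) prior on true_p (Episuite's prior may differ).
prior_a, prior_b = 1.0, 1.0
total_observations = 4000
positive_observations = 20

# Beta-Binomial conjugacy: the posterior is Beta(a + x, b + n - x).
post_a = prior_a + positive_observations
post_b = prior_b + total_observations - positive_observations
s = post_a + post_b

post_mean = post_a / s
post_mode = (post_a - 1.0) / (s - 2.0)  # equals the MLE x / n under a flat prior
post_sd = math.sqrt(post_a * post_b / (s ** 2 * (s + 1.0)))

print(f"posterior Beta({post_a:.0f}, {post_b:.0f}): "
      f"mean={post_mean:.5f}, mode={post_mode:.5f}, sd={post_sd:.5f}")
```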

Apparent prevalence model

In this section we will estimate an apparent prevalence model: a model that incorporates the sensitivity and specificity estimated from the test's validation results. The scenario is a set of samples tested for SARS-CoV-2 antibodies, assuming the validation properties of a real test from the brand Wondfo (used in several seroprevalence surveys in Brazil).

[12]:
# Wondfo test validation counts (from the manufacturer's product description, using PCR as the gold standard)
#
# From a total of 42 confirmed COVID-19 positive patients: the test detected 42 positive and 0 negative.
# From a total of 172 COVID-19 negative patients: the test detected 2 positive and 170 negative.

# Specificity parameters
n_sp = 172
x_sp = 170

# Sensitivity parameters
n_se = 42
x_se = 42

# These are results from a seroprevalence study in Brazil
observed_total = 4189
observed_positive = 2
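To see why propagating test uncertainty matters for these numbers, a quick check with plain point estimates is useful. The sketch below applies the classical Rogan-Gladen correction (a standard estimator swapped in for illustration; it is not part of the Episuite workflow shown here) to the counts above. The estimated false-positive rate, 2/172 (about 1.2%), exceeds the observed positive rate, 2/4189 (about 0.05%), so the corrected estimate comes out negative and must be clipped to zero. The Bayesian model instead keeps true_p in [0, 1] and lets the survey data also inform the specificity.

```python
# Wondfo validation counts and survey counts from the cell above
n_sp, x_sp = 172, 170
n_se, x_se = 42, 42
observed_total, observed_positive = 4189, 2

se_hat = x_se / n_se                          # point estimate of sensitivity
sp_hat = x_sp / n_sp                          # point estimate of specificity
ap_hat = observed_positive / observed_total   # observed (apparent) prevalence

# Rogan-Gladen: solve apparent = true*se + (1 - true)*(1 - sp) for true
rg = (ap_hat + sp_hat - 1.0) / (se_hat + sp_hat - 1.0)
# The usual fix for out-of-range values is to clip to [0, 1]
rg_clipped = min(max(rg, 0.0), 1.0)

print(f"se={se_hat:.4f} sp={sp_hat:.4f} apparent={ap_hat:.5f} "
      f"false-positive rate={1 - sp_hat:.4f} RG={rg:.5f} clipped={rg_clipped}")
```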
[13]:
kernel = NUTS(prevalence.apparent_prevalence_model,
              init_strategy=init_to_feasible())
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=1)
[14]:
mcmc.run(rng_key_,
         x_se=x_se, n_se=n_se,           # Sensitivity parameters of the test used
         x_sp=x_sp, n_sp=n_sp,           # Specificity parameters of the test used
         obs_positive=observed_positive, # Positive results
         obs_total=observed_total)       # Total samples
sample: 100%|██████████| 2500/2500 [00:07<00:00, 352.70it/s, 7 steps of size 5.02e-01. acc. prob=0.91]
[15]:
mcmc.print_summary(exclude_deterministic=False)
samples_1 = mcmc.get_samples()

                  mean       std    median      5.0%     95.0%     n_eff     r_hat
  apparent_p      0.00      0.00      0.00      0.00      0.00   1480.38      1.00
        se_p      0.98      0.02      0.98      0.95      1.00   1687.51      1.00
        sp_p      1.00      0.00      1.00      1.00      1.00   1247.72      1.00
      true_p      0.00      0.00      0.00      0.00      0.00   1587.27      1.00

Number of divergences: 0
[16]:
inference_data = az.from_numpyro(mcmc)
[17]:
az.plot_forest(inference_data)
plt.show()
[Figure: forest plot of apparent_p, se_p, sp_p and true_p]
[18]:
az.plot_trace(inference_data)
plt.show()
[Figure: trace plots of apparent_p, se_p, sp_p and true_p]
[19]:
az.plot_posterior(inference_data, round_to=3,
                  point_estimate="mode", var_names=["true_p", "apparent_p"])
plt.show()
[Figure: posterior distributions of true_p and apparent_p with their modes]
[20]:
az.plot_pair(inference_data, var_names=["se_p", "sp_p"], kind="kde",
             colorbar=True, figsize=(10, 8), kde_kwargs={"fill_last": True})
plt.show()
[Figure: joint KDE of se_p and sp_p]

Note

Please note that this example uses only 1 MCMC chain and relatively few samples. In a real analysis you are advised to run multiple chains, which enables convergence diagnostics such as r_hat, and to draw many more samples.