%reset -f
import jax
import jax.numpy as jnp
from memo import memo
from enum import IntEnum
from jax.scipy.stats.norm import pdf as normpdf
from jax.scipy.stats.cauchy import pdf as cauchypdf
from matplotlib import pyplot as plt
normpdfjit = jax.jit(normpdf)

class Department(IntEnum):
    GOVERNMENT = 0
    ENGLISH = 1
    MATH = 2

t = jnp.array([1, -15, 30])
sigma = 15

Theta = jnp.linspace(-40, 40, 200)

@jax.jit
def professor_arrival_likelihood(d, theta):
    ### likelihood of a professor from department d
    ### showing up t_d minutes early/late,
    ### for a given theta and sigma.
    return normpdf(t[d], theta, sigma)
@memo
def complete_pooling[
    _theta: Theta,
](mu=0, tau=1):
    president: knows(_theta)
    president: thinks[
        department: chooses(d in Department, wpp=1),
        department: chooses(theta in Theta, wpp=normpdfjit(theta, mu, tau))
    ]
    president: observes_event(wpp=professor_arrival_likelihood(department.d, department.theta))
    return president[Pr[department.theta == _theta]]
mu_ = 0
tau_ = 20
res = complete_pooling(mu=mu_, tau=tau_)
### check the size and sum of the output
# res.shape
# res.sum()
fig, ax = plt.subplots()

ax.axvline(0, color="black", linestyle="-")
theta_posterior = res
theta_expectation = jnp.dot(Theta, theta_posterior)
ax.plot(Theta, theta_posterior, label=r"$P(\theta \mid t)$")
ax.axvline(
    theta_expectation,
    color='red',
    linestyle='--',
    label=(
        r"$\operatorname{E}"
        + rf"[\theta \mid t]={theta_expectation:0.2f}$")
)
_ = ax.set_title(r"Posterior of $\theta$")
_ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

_ = plt.suptitle(fr"$\dot\tau$ = {tau_}", y=1)

plt.tight_layout()
plt.show()
Multilevel models
NB the Jupyter notebook includes interactive widgets that are very useful for building an understanding of these models.
Departmental tardiness
Imagine that your university hires a new president. On her first week, she schedules a meeting with professors from three departments, { D = \{\text{G}, \text{E}, \text{M} \} }. The professor from the Department of Government ({ d{=}\text{G} }) shows up 1 minute late, the professor from the Department of English ({ d{=}\text{E} }) shows up 15 minutes before the scheduled meeting time, and the professor from the Department of Math ({ d{=}\text{M} }) shows up 30 minutes late. From this single meeting the new president updates her beliefs about when professors will show up to scheduled meetings.
But how should the president update her beliefs? It could be that departments have different relationships with punctuality and it would be useful to expect professors from the Math department to show up later than professors from the English department. It could also be the case that there is not meaningful between-department variance and she should just learn about the greater population of professors at the university.
A model that infers information about the population without differentiating subpopulations is doing “complete variance pooling”, meaning that all of the variance in the data is treated as being generated by the same distribution. I.e. the data is pooled together and treated as iid observations drawn from a single distribution.
If the president thinks of each department independently, what she learns about the tardiness of professors from one department will not change her expectation about professors from other departments. In this case, there would be “no variance pooling” between departments. The president’s observations of the three professors have no bearing on her belief about what to expect of a professor from a fourth department (e.g. Psychology).
“Partial variance pooling” models represent uncertainty at multiple levels of abstraction. Based on her observation from this meeting, the president will update her beliefs about these departments and the greater population. In other words, what she observes about the Math department might strongly update her beliefs about that department, moderately update her beliefs about the broader population, and that in turn could weakly update her beliefs about the English department.
Complete pooling
Let’s build a complete pooling model of the president’s belief about the tardiness of professors. The data observed by the president is { t = \langle t_G, t_E, t_M \rangle = \langle 1, -15, 30 \rangle }. We’ll model the president’s belief about the greater population of professors as a normal distribution centered at \theta with standard deviation \dot\sigma. Prior to meeting with anyone, the president thinks that professors will, on average, show up on time, but that some will show up early and some late. We can encode this prior belief as a normal distribution centered at zero, with some standard deviation, let’s say 20. I.e. the prior for \theta is:
\theta ~\sim~ \mathcal{N}\left( \dot{\mu}{=}0, \dot{\tau}{=}20 \right)
If she observes that professors tend to be late, then her posterior estimate of \theta will have an expected value greater than zero (i.e. { \operatorname{E}[\theta \mid t] > 0 }), and if she ends up expecting professors to be early, then { \operatorname{E}[\theta \mid t] < 0 }.
In this tutorial I’ll use the dot above to indicate that a variable is fixed, not inferred (the dot is not standard notation, just what I’m using here). In the graphical schematic, fixed parameters are depicted as bare symbols rather than nodes. Note that “fixed” means the values are not updated during inference (but you can, of course, manually tune these values or learn suitable values by minimizing some loss function in conjunction with a model fitting algorithm).
Since the model infers the distribution of \theta, the graphical schematic depicts it as an open node (i.e. a latent random variable). The data in this model are t.
We specify the likelihood as,
t_d ~\sim~ \mathcal{N}(\theta, \dot\sigma{=}15)
I.e., for whatever the true value of \theta is for the population of professors, the president thinks that when professors actually show up will follow a normal distribution, centered on the true value of \theta, with standard deviation 15.
\begin{align*} \theta ~\sim&~ \mathcal{N}\left( \dot{\mu}{=}0, \dot{\tau}{=}20 \right) \\ t_d ~\sim&~ \mathcal{N}(\theta, \dot\sigma{=}15) \\ d ~\in&~ D,~~~ \text{where}~~~ D = \{ \text{G}, \text{E}, \text{M} \} \end{align*}
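Given this generative model, Bayes' rule yields the president's posterior over \theta (up to a normalizing constant):

P(\theta \mid t) ~\propto~ \mathcal{N}(\theta \mid \dot\mu, \dot\tau) \prod_{d \in D} \mathcal{N}(t_d \mid \theta, \dot\sigma)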
Try changing the data to t = jnp.array([-20, -10, 30]). Notice how changing the data for one group updates the expectation of the population.
No pooling
In contrast to complete pooling, which assumes all observations are generated by a shared population-level distribution parameterized by \theta, the no pooling model treats each department as having its own independent parameter \theta_d. The key change is that instead of having a single \theta drawn from the prior distribution, we now have three independent parameters \theta_G, \theta_E, and \theta_M. The practical meaning of this change is significant:
In complete pooling, an observation of the Math professor being late would shift our expectations for all departments. In no pooling, learning that the Math professor is late only updates our beliefs about the Math department. Each department’s parameter is estimated using only data from that department.
The key change in our model specification is that we have replaced the single \theta with three parameters, \theta_d.
\begin{align*} \theta_d ~\sim&~ \mathcal{N}\left( \dot{\mu}{=}0, \dot{\tau}{=}20 \right) \\ t_d ~\sim&~ \mathcal{N}(\theta_d, \dot\sigma{=}15) \\ d ~\in&~ D,~~~ \text{where}~~~ D = \{\text{G}, \text{E}, \text{M} \} \end{align*}
The plate (rectangular box) with label d indicates that everything inside is repeated for each department d \in D. Variables inside the plate have a unique value for each department. Variables outside the plate are shared across all departments.
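Because each \theta_d has its own independent prior and each t_d depends only on its own \theta_d, the posterior factorizes by department:

P(\theta_d \mid t) ~=~ P(\theta_d \mid t_d) ~\propto~ \mathcal{N}(\theta_d \mid \dot\mu, \dot\tau)\, \mathcal{N}(t_d \mid \theta_d, \dot\sigma)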
To change the complete_pooling model to the no_pooling model we make a simple change, adding the conditional statement president: observes [department.d] is _d and the query variable _d in the definition, [_d: Department, ...].
%reset -f
import jax
import jax.numpy as jnp
from memo import memo
from enum import IntEnum
from jax.scipy.stats.norm import pdf as normpdf
from jax.scipy.stats.cauchy import pdf as cauchypdf
from matplotlib import pyplot as plt
normpdfjit = jax.jit(normpdf)

class Department(IntEnum):
    GOVERNMENT = 0
    ENGLISH = 1
    MATH = 2

t = jnp.array([1, -15, 30])
sigma = 15

Theta = jnp.linspace(-40, 40, 200)

@jax.jit
def professor_arrival_likelihood(d, theta):
    ### likelihood of a professor from department d
    ### showing up t_d minutes early/late,
    ### for a given theta and sigma.
    return normpdf(t[d], theta, sigma)
@memo
def no_pooling[
    _d: Department,
    _theta: Theta,
](mu=0, tau=1):
    president: knows(_theta)
    president: thinks[
        department: chooses(d in Department, wpp=1),
        department: chooses(theta in Theta, wpp=normpdfjit(theta, mu, tau))
    ]
    president: observes [department.d] is _d  ### new conditional statement
    president: observes_event(wpp=professor_arrival_likelihood(department.d, department.theta))
    return president[Pr[department.theta == _theta]]
mu_ = 0
tau_ = 20
res = no_pooling(mu=mu_, tau=tau_)
### check the size and sum of the output
# res.shape
# res.sum()
# res[0].sum()
# res[1].sum()
# res[2].sum()
fig, ax = plt.subplots()

ax.axvline(0, color="black", linestyle="-")
for d in Department:
    department_name = d.name
    department_abbrev = department_name[0]
    theta_posterior = res[d]
    theta_expectation = jnp.dot(Theta, theta_posterior)

    ax.plot(
        Theta,
        theta_posterior,
        label=(
            rf"$P(\theta_{department_abbrev} \mid t),~ "
            + r"\operatorname{E}"
            + rf"[\theta_{department_abbrev} \mid t]={theta_expectation:0.2f}$")
    )
_ = ax.set_title(r"Posterior of $\theta_d$")
_ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

_ = plt.suptitle(fr"$\dot\tau$ = {tau_}", y=1)

plt.tight_layout()
plt.show()
Try changing the data to t = jnp.array([-20, -15, 30]). Notice how changing the data for one group does not affect the expectation of the posterior of the other groups.
Previously, the complete_pooling model returned an array of shape (200,), with 200 being the size of the sample space of Theta. With the addition of the conditional statement, what array shape does no_pooling return? Make sure you understand the shape change, what each dimension reflects, and why the values sum in the ways that they do.
The key theoretical distinction is that in no pooling, we’re asserting that the timing behavior of professors in different departments is independent of the behavior of professors from other departments (i.e. when professors from one department show up for meetings carries no information about when professors from a different department will show up). This is an unrealistically strong assumption.
The no_pooling and complete_pooling models highlight an important practical tradeoff: no pooling can better capture differences between departments but may suffer from high variance in its estimates due to small sample sizes within each department. Complete pooling has lower variance but may miss important systematic differences between departments.
Partial pooling
In partial pooling, every observation updates beliefs across multiple levels of abstraction. The critical insight is that while departments may differ in their typical timing behavior, the departments are not statistically independent. They are influenced by common causes. To model the president’s reasoning about these multiple levels of uncertainty, we modify the no_pooling model by adding priors over \mu and \tau.
Rather than being fixed values, \mu and \tau become latent parameters to be jointly inferred alongside \theta_d.
The higher-level latent parameters \mu and \tau characterize a “population-level” distribution from which department-specific \theta_d values are generated.
The key change to the model is that rather than { \theta_d ~\sim~ \mathcal{N}\left( \dot{\mu}{=}0, \dot{\tau}{=}20 \right) }, we have { \theta_d ~\sim~ \mathcal{N}\left(\mu, \tau \right) } and specify hyperpriors on these new latent parameters: { \mu ~\sim~ \mathcal{N}(0, \dot\sigma_{\mu}) } and { \tau ~\sim~ \text{HalfCauchy}(\dot\sigma_{\tau}) }.
Note that there are new fixed parameters: mu_scale ({ \dot\sigma_\mu }) and tau_scale ({ \dot\sigma_\tau }), but I am omitting them from the graphical schematic for the sake of visual simplicity.
\begin{align*} \mu ~\sim&~ \mathcal{N}(0, \dot\sigma_{\mu}) \\ \tau ~\sim&~ \text{HalfCauchy}(\dot\sigma_{\tau}) \\ \theta_d ~\sim&~ \mathcal{N}\left(\mu, \tau \right) \\ t_d ~\sim&~ \mathcal{N}(\theta_d, \dot\sigma{=}15) \\ d ~\in&~ D,~~~ \text{where}~~~ D = \{\text{G}, \text{E}, \text{M} \} \end{align*}
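Here \text{HalfCauchy}(\dot\sigma_{\tau}) denotes the positive half of a Cauchy distribution with scale \dot\sigma_{\tau} (implemented below as 2 * cauchypdf(x, 0, scale), evaluated on the positive grid Tau). Under this model, the joint posterior is, up to normalization,

P(\mu, \tau, \theta_G, \theta_E, \theta_M \mid t) ~\propto~ \mathcal{N}(\mu \mid 0, \dot\sigma_{\mu})\, \text{HalfCauchy}(\tau \mid \dot\sigma_{\tau}) \prod_{d \in D} \mathcal{N}(\theta_d \mid \mu, \tau)\, \mathcal{N}(t_d \mid \theta_d, \dot\sigma)

The partial_pooling model below reports the marginal { P(\mu, \tau \mid t) }, with the department-level \theta_d summed out inside department_model; the marginal posterior over each \theta_d is then recovered by department_model_theta.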
%reset -f
import jax
import jax.numpy as jnp
from memo import memo
from enum import IntEnum
from jax.scipy.stats.norm import pdf as normpdf
from jax.scipy.stats.cauchy import pdf as cauchypdf
from matplotlib import pyplot as plt
normpdfjit = jax.jit(normpdf)

class Department(IntEnum):
    GOVERNMENT = 0
    ENGLISH = 1
    MATH = 2

t = jnp.array([1, -15, 30])
sigma = 15

Mu = jnp.linspace(-25, 25, 100)    ### sample space for new hyperprior
Tau = jnp.linspace(1, 30, 100)     ### sample space for new hyperprior
Theta = jnp.linspace(-40, 40, 200)
### PDF for new hyperprior
@jax.jit
def half_cauchy(x, scale=1.0):
    return 2 * cauchypdf(x, 0, scale)

@jax.jit
def professor_arrival_likelihood(d, theta):
    ### likelihood of a professor from department d
    ### showing up t_d minutes early/late,
    ### for a given theta and sigma.
    return normpdf(t[d], theta, sigma)
@memo
def department_model[_mu: Mu, _tau: Tau](d):
    department: knows(_mu, _tau)
    department: chooses(theta in Theta, wpp=normpdfjit(theta, _mu, _tau))
    return E[ department[professor_arrival_likelihood(d, theta)] ]
@memo
def partial_pooling[_mu: Mu, _tau: Tau](mu_scale=5, tau_scale=5):
    president: knows(_mu, _tau)
    president: thinks[
        population: chooses(mu in Mu, wpp=normpdfjit(mu, 0, mu_scale)),
        population: chooses(tau in Tau, wpp=half_cauchy(tau, tau_scale)),
    ]
    president: observes_event(wpp=department_model[population.mu, population.tau]({Department.GOVERNMENT}))
    president: observes_event(wpp=department_model[population.mu, population.tau]({Department.ENGLISH}))
    president: observes_event(wpp=department_model[population.mu, population.tau]({Department.MATH}))
    return president[Pr[population.mu == _mu, population.tau == _tau]]
Let’s think about what each variable represents.
The likelihood model professor_arrival_likelihood gives the probability that a professor from department d shows up t_d minutes early/late, under a given \theta_d and \sigma:
t_d ~\sim~ \mathcal{N}(\theta_d, \dot\sigma{=}15)
We can write the likelihood as { P(t_d \mid \theta_d, \mu, \tau ; \sigma) }. The marginal posterior { P(\theta_d \mid t) } will thus represent the president’s belief about the true value of \theta_d for a given department d.
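On the discrete grids used here, the department_model_theta model below computes this marginal by averaging the department-level prior over the posterior on \mu and \tau and reweighting by the department's likelihood:

P(\theta_d \mid t) ~\propto~ \sum_{\mu \in \text{Mu}} \sum_{\tau \in \text{Tau}} P(\mu, \tau \mid t)\; \mathcal{N}(\theta_d \mid \mu, \tau)\; P(t_d \mid \theta_d ; \dot\sigma)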
@memo
def department_model_theta[_theta: Theta](d, mu_scale, tau_scale):
    obs: thinks[
        department: chooses(mu in Mu, tau in Tau, wpp=partial_pooling[mu, tau](mu_scale, tau_scale)),
        department: chooses(theta in Theta, wpp=normpdfjit(theta, mu, tau))
    ]
    obs: observes_event(wpp=professor_arrival_likelihood(d, department.theta))
    obs: knows(_theta)
    return obs[Pr[department.theta == _theta]]
Plotting function
def plot_model(mu_scale=1, tau_scale=1, figsize=(10, 8), verbose=False):
    posterior = partial_pooling(mu_scale=mu_scale, tau_scale=tau_scale)

    # Marginal over Tau (sum over Mu)
    posterior_tau = posterior.sum(axis=0)
    # Marginal over Mu (sum over Tau)
    posterior_mu = posterior.sum(axis=1)

    fig, axs = plt.subplots(3, 1, figsize=figsize)

    ax = axs[0]
    ax.axvline(0, color="black", linestyle="-")
    ax.plot(Mu, posterior_mu, label=r"$P(\mu \mid t)$")
    mu_expectation = jnp.dot(Mu, posterior_mu)
    ax.axvline(
        mu_expectation,
        color='red',
        linestyle='--',
        label=r"$\operatorname{E}" + rf"[\mu \mid t]={mu_expectation:0.2f}$")
    _ = ax.set_title(r"Posterior of $\mu$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    ax = axs[1]
    ax.axvline(0, color="black", linestyle="-")
    ax.plot(Tau, posterior_tau, label=r"$P(\tau \mid t)$")
    tau_expectation = jnp.dot(Tau, posterior_tau)
    ax.axvline(
        tau_expectation,
        color='red',
        linestyle='--',
        label=r"$\operatorname{E}" + rf"[\tau \mid t]={tau_expectation:0.2f}$")
    _ = ax.set_title(r"Posterior of $\tau$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    ax = axs[2]
    ax.axvline(0, color="black", linestyle="-")
    for d in Department:
        department_name = d.name
        department_abbrev = department_name[0]
        theta_posterior = department_model_theta(d, mu_scale, tau_scale)
        theta_expectation = jnp.dot(Theta, theta_posterior)

        ax.plot(
            Theta,
            theta_posterior,
            label=(
                rf"$P(\theta_{department_abbrev} \mid t),~ "
                + r"\operatorname{E}"
                + rf"[\theta_{department_abbrev} \mid t]={theta_expectation:0.2f}$"))
    _ = ax.set_title(r"Posterior of $\theta_d$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    _ = plt.suptitle(f"mu_scale = {mu_scale}, tau_scale = {tau_scale}", y=1)

    plt.tight_layout()
    plt.show()

    if verbose:
        for d in Department:
            department_name = d.name
            department_abbrev = department_name[0]
            theta_posterior = department_model_theta(d, mu_scale, tau_scale)
            posterior_mean = jnp.average(Theta, weights=theta_posterior)
            posterior_second_moment = jnp.average(Theta**2, weights=theta_posterior)
            posterior_variance = posterior_second_moment - posterior_mean**2
            print(f"{mu_scale=}, {tau_scale=} : E[θ{department_abbrev} | t]={posterior_mean:0.3f}, Var[θ{department_abbrev} | t]={posterior_variance:0.3f}")
When mu_scale is small (e.g. { \dot\sigma_{\mu}{=}1 }) the location of the distribution over \theta_d will be centered near zero. When mu_scale is large (e.g. { \dot\sigma_{\mu}{=}10 }), the center of { \mathcal{N}(\mu, \tau) } can move to new locations, resulting in posterior distributions with similar variance but expected values quite different from each other.
param_list = [
    dict(mu_scale=1, tau_scale=3),
    dict(mu_scale=3, tau_scale=3),
    dict(mu_scale=10, tau_scale=3),
]

for i_params_, params_ in enumerate(param_list):
    plot_model(**params_, figsize=(6, 5), verbose=True)
mu_scale=1, tau_scale=3 : E[θG | t]=0.187, Var[θG | t]=29.778
mu_scale=1, tau_scale=3 : E[θE | t]=-2.155, Var[θE | t]=40.307
mu_scale=1, tau_scale=3 : E[θM | t]=6.092, Var[θM | t]=82.286
mu_scale=3, tau_scale=3 : E[θG | t]=0.595, Var[θG | t]=35.435
mu_scale=3, tau_scale=3 : E[θE | t]=-2.168, Var[θE | t]=46.285
mu_scale=3, tau_scale=3 : E[θM | t]=6.891, Var[θM | t]=79.738
mu_scale=10, tau_scale=3 : E[θG | t]=2.370, Var[θG | t]=61.138
mu_scale=10, tau_scale=3 : E[θE | t]=-2.297, Var[θE | t]=73.436
mu_scale=10, tau_scale=3 : E[θM | t]=10.948, Var[θM | t]=80.582
When tau_scale is small (e.g. { \dot\sigma_{\tau}{=}1 }), the distribution { \theta_d ~\sim~ \mathcal{N}(\mu, \tau) } will be narrow, which will pressure the individual { \theta_d } to be similar to each other. When tau_scale is large (e.g. { \dot\sigma_{\tau}{=}20 }), the { \theta_d } can diverge more freely, resulting in larger posterior variance and expected values quite different from each other.
param_list = [
    dict(mu_scale=1, tau_scale=1),
    dict(mu_scale=1, tau_scale=3),
    dict(mu_scale=1, tau_scale=10),
    dict(mu_scale=1, tau_scale=20),
]

for i_params_, params_ in enumerate(param_list):
    plot_model(**params_, figsize=(6, 5), verbose=True)
mu_scale=1, tau_scale=1 : E[θG | t]=0.139, Var[θG | t]=17.512
mu_scale=1, tau_scale=1 : E[θE | t]=-1.247, Var[θE | t]=24.192
mu_scale=1, tau_scale=1 : E[θM | t]=3.798, Var[θM | t]=55.836
mu_scale=1, tau_scale=3 : E[θG | t]=0.187, Var[θG | t]=29.778
mu_scale=1, tau_scale=3 : E[θE | t]=-2.155, Var[θE | t]=40.307
mu_scale=1, tau_scale=3 : E[θM | t]=6.092, Var[θM | t]=82.286
mu_scale=1, tau_scale=10 : E[θG | t]=0.298, Var[θG | t]=57.456
mu_scale=1, tau_scale=10 : E[θE | t]=-4.140, Var[θE | t]=73.415
mu_scale=1, tau_scale=10 : E[θM | t]=10.436, Var[θM | t]=116.686
mu_scale=1, tau_scale=20 : E[θG | t]=0.366, Var[θG | t]=74.551
mu_scale=1, tau_scale=20 : E[θE | t]=-5.324, Var[θE | t]=91.824
mu_scale=1, tau_scale=20 : E[θM | t]=12.698, Var[θM | t]=127.697
In this way, a large { \dot\sigma_{\mu} } and small { \dot\sigma_{\tau} } would allow { \mathcal{N}(\mu, \tau) } to easily move to a new location that accommodates the broader tendencies of departments and professors, but would constrain the dispersion between departments. Notice how increasing mu_scale allows the posterior { P(\mu \mid t) } to move away from zero towards the mean of the combined data, whereas decreasing mu_scale pulls the expectation { \operatorname{E}[\mu \mid t] } towards zero. Also notice how this effect is stronger when tau_scale is small.
param_list = [
    dict(mu_scale=1, tau_scale=1),
    # dict(mu_scale=5, tau_scale=5),
    # dict(mu_scale=10, tau_scale=10),
    dict(mu_scale=20, tau_scale=20),
    # dict(mu_scale=10, tau_scale=1),
    dict(mu_scale=20, tau_scale=3),
    # dict(mu_scale=1, tau_scale=10),
    dict(mu_scale=3, tau_scale=20),
]

for i_params_, params_ in enumerate(param_list):
    plot_model(**params_, figsize=(6, 5), verbose=True)
mu_scale=1, tau_scale=1 : E[θG | t]=0.139, Var[θG | t]=17.512
mu_scale=1, tau_scale=1 : E[θE | t]=-1.247, Var[θE | t]=24.192
mu_scale=1, tau_scale=1 : E[θM | t]=3.798, Var[θM | t]=55.836
mu_scale=20, tau_scale=20 : E[θG | t]=2.517, Var[θG | t]=112.461
mu_scale=20, tau_scale=20 : E[θE | t]=-6.000, Var[θE | t]=129.325
mu_scale=20, tau_scale=20 : E[θM | t]=16.936, Var[θM | t]=112.901
mu_scale=20, tau_scale=3 : E[θG | t]=3.228, Var[θG | t]=74.198
mu_scale=20, tau_scale=3 : E[θE | t]=-2.422, Var[θE | t]=87.779
mu_scale=20, tau_scale=3 : E[θM | t]=13.039, Var[θM | t]=83.766
mu_scale=3, tau_scale=20 : E[θG | t]=0.632, Var[θG | t]=78.673
mu_scale=3, tau_scale=20 : E[θE | t]=-5.378, Var[θE | t]=96.102
mu_scale=3, tau_scale=20 : E[θM | t]=13.063, Var[θM | t]=124.248
Multiple observations with missing data
What if the president instead had a large meeting with 7 professors and the departments were not equally represented? How would she update her beliefs given unequal observations? Let’s say that she invited 3 professors from the Department of Government, 3 from the Department of English, and 1 from the Department of Math. We can explore how modulating the higher-level hyperpriors changes her posterior inference about each department.
\begin{align*} \mu ~\sim&~ \mathcal{N}(0, \dot\sigma_{\mu}) \\ \tau ~\sim&~ \text{HalfCauchy}(\dot\sigma_{\tau}) \\ \theta_d ~\sim&~ \mathcal{N}\left(\mu, \tau \right) \\ t_{(d,i)} ~\sim&~ \mathcal{N}(\theta_d, \dot\sigma{=}15) \\ d ~\in&~ D,~~~ \text{where}~~~ D = \{\text{G}, \text{E}, \text{M} \} \end{align*}
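With several professors per department, a department's likelihood is the product of the likelihoods of its observed professors. In the code below, this product is computed by summing log-densities, and the unequal group sizes are handled by padding the data array with jnp.nan and dropping those entries via jnp.nansum. A minimal sketch of that trick (the row matches the Math department's data; \theta{=}0 and \sigma{=}15 are just example values):

import jax.numpy as jnp
from jax.scipy.stats.norm import logpdf as normlogpdf

row = jnp.array([30.0, jnp.nan, jnp.nan])  # one observed professor, two missing
logliks = normlogpdf(row, 0.0, 15.0)       # missing entries evaluate to nan
total_loglik = jnp.nansum(logliks)         # nansum treats nan as 0, dropping them
print(jnp.exp(total_loglik))               # likelihood from the observed professor only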
%reset -f
import jax
import jax.numpy as jnp
from memo import memo
from enum import IntEnum
from jax.scipy.stats.norm import pdf as normpdf
from jax.scipy.stats.norm import logpdf as normlogpdf
from jax.scipy.stats.cauchy import pdf as cauchypdf
from matplotlib import pyplot as plt
normpdfjit = jax.jit(normpdf)

class Department(IntEnum):
    GOVERNMENT = 0
    ENGLISH = 1
    MATH = 2

###NEW
t = jnp.array([
    [-10, 1, 11],
    [-16, -15, -14],
    [30, jnp.nan, jnp.nan],
])

sigma = 15

Mu = jnp.linspace(-25, 25, 100)
Tau = jnp.linspace(1, 30, 100)
Theta = jnp.linspace(-40, 40, 200)
@jax.jit
def half_cauchy(x, scale=1.0):
    return 2 * cauchypdf(x, 0, scale)

@jax.jit
def professor_arrival_likelihood(d, theta):
    ### likelihood of a professor from department d
    ### showing up t_d minutes early/late,
    ### for a given theta and sigma.
    return jnp.exp(jnp.nansum(normlogpdf(t[d], theta, sigma))) ###NEW
@memo
def department_model[_mu: Mu, _tau: Tau](d):
    department: knows(_mu, _tau)
    department: chooses(theta in Theta, wpp=normpdfjit(theta, _mu, _tau))
    return E[ department[professor_arrival_likelihood(d, theta)] ]
@memo
def partial_pooling[_mu: Mu, _tau: Tau](mu_scale=5, tau_scale=5):
    president: knows(_mu, _tau)
    president: thinks[
        population: chooses(mu in Mu, wpp=normpdfjit(mu, 0, mu_scale)),
        population: chooses(tau in Tau, wpp=half_cauchy(tau, tau_scale)),
    ]
    president: observes_event(wpp=department_model[population.mu, population.tau]({Department.GOVERNMENT}))
    president: observes_event(wpp=department_model[population.mu, population.tau]({Department.ENGLISH}))
    president: observes_event(wpp=department_model[population.mu, population.tau]({Department.MATH}))
    return president[Pr[population.mu == _mu, population.tau == _tau]]
@memo
def department_model_theta[_theta: Theta](d, mu_scale, tau_scale):
    obs: thinks[
        department: chooses(mu in Mu, tau in Tau, wpp=partial_pooling[mu, tau](mu_scale, tau_scale)),
        department: chooses(theta in Theta, wpp=normpdfjit(theta, mu, tau))
    ]
    obs: observes_event(wpp=professor_arrival_likelihood(d, department.theta))
    obs: knows(_theta)
    return obs[Pr[department.theta == _theta]]
Plotting function
def plot_model(mu_scale=1, tau_scale=1, figsize=(10, 8)):
    posterior = partial_pooling(mu_scale=mu_scale, tau_scale=tau_scale)

    # Marginal over Tau (sum over Mu)
    posterior_tau = posterior.sum(axis=0)
    # Marginal over Mu (sum over Tau)
    posterior_mu = posterior.sum(axis=1)

    fig, axs = plt.subplots(3, 1, figsize=figsize)

    ax = axs[0]
    ax.axvline(0, color="black", linestyle="-")
    ax.plot(Mu, posterior_mu, label=r"$P(\mu \mid t)$")
    mu_expectation = jnp.dot(Mu, posterior_mu)
    ax.axvline(
        mu_expectation,
        color='red',
        linestyle='--',
        label=r"$\operatorname{E}" + rf"[\mu \mid t]={mu_expectation:0.2f}$")
    _ = ax.set_title(r"Posterior of $\mu$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    ax = axs[1]
    ax.axvline(0, color="black", linestyle="-")
    ax.plot(Tau, posterior_tau, label=r"$P(\tau \mid t)$")
    tau_expectation = jnp.dot(Tau, posterior_tau)
    ax.axvline(
        tau_expectation,
        color='red',
        linestyle='--',
        label=r"$\operatorname{E}" + rf"[\tau \mid t]={tau_expectation:0.2f}$")
    _ = ax.set_title(r"Posterior of $\tau$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    ax = axs[2]
    ax.axvline(0, color="black", linestyle="-")
    for d in Department:
        department_name = d.name
        department_abbrev = department_name[0]
        theta_posterior = department_model_theta(d, mu_scale, tau_scale)
        theta_expectation = jnp.dot(Theta, theta_posterior)

        ax.plot(
            Theta,
            theta_posterior,
            label=(
                rf"$P(\theta_{department_abbrev} \mid t),~ "
                + r"\operatorname{E}"
                + rf"[\theta_{department_abbrev} \mid t]={theta_expectation:0.2f}$"))
    _ = ax.set_title(r"Posterior of $\theta_d$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    _ = plt.suptitle(f"mu_scale = {mu_scale}, tau_scale = {tau_scale}", y=1)

    plt.tight_layout()
    plt.show()
param_list = [
    dict(mu_scale=1, tau_scale=1),
    # dict(mu_scale=5, tau_scale=5),
    # dict(mu_scale=10, tau_scale=10),
    dict(mu_scale=20, tau_scale=20),
    # dict(mu_scale=10, tau_scale=1),
    dict(mu_scale=20, tau_scale=1),
    # dict(mu_scale=1, tau_scale=10),
    dict(mu_scale=1, tau_scale=20),
]

for i_params_, params_ in enumerate(param_list):
    plot_model(**params_, figsize=(6, 5))
When tau_scale is small (e.g. { \dot\sigma_{\tau}{=}1 }), the distribution { \theta_d ~\sim~ \mathcal{N}(\mu, \tau) } will be narrow, which will pressure the individual { \theta_d } to be similar to each other. When tau_scale is large (e.g. { \dot\sigma_{\tau}{=}20 }), the { \theta_d } can diverge more freely, resulting in values quite different from each other. Think about how this will interact with the number of observations and the values of those observations.
When mu_scale is small (e.g. { \dot\sigma_{\mu}{=}1 }) the location of the distribution over \theta_d will be centered near zero. When mu_scale is large (e.g. { \dot\sigma_{\mu}{=}20 }), the center of { \mathcal{N}(\mu, \tau) } can move to new locations.
In this way, a large { \dot\sigma_{\mu} } and small { \dot\sigma_{\tau} } would allow { \mathcal{N}(\mu, \tau) } to easily move to a new location that accommodates the broader tendencies of departments and professors, but would constrain the dispersion between departments. Notice how increasing mu_scale allows the posterior { P(\mu \mid t) } to move away from zero towards the mean of the combined data, whereas decreasing mu_scale pulls the expectation { \operatorname{E}[\mu \mid t] } towards zero. Also notice how this effect is stronger when tau_scale is small.
Pay close attention to the interaction between mu_scale, tau_scale, and the observations for each department (including the expectation and variance of the observations of each department).
Multiple observations with missing data and learned variance
\begin{align*} \mu ~\sim&~ \mathcal{N}(0, \dot\sigma_{\mu}) \\ \tau ~\sim&~ \text{HalfCauchy}(\dot\sigma_{\tau}) \\ \theta_d ~\sim&~ \mathcal{N}(\mu, \tau) \\ \sigma_d ~\sim&~ \text{HalfCauchy}(5) \\ t_{(d,i)} ~\sim&~ \mathcal{N}(\theta_d, \sigma_d) \\ d ~\in&~ \{ 0, 1, 2\} \end{align*}
%reset -f
import jax
import jax.numpy as jnp
from memo import memo
from enum import IntEnum
from jax.scipy.stats.norm import pdf as normpdf
from jax.scipy.stats.norm import logpdf as normlogpdf
from jax.scipy.stats.cauchy import pdf as cauchypdf
from matplotlib import pyplot as plt
normpdfjit = jax.jit(normpdf)

class Department(IntEnum):
    GOVERNMENT = 0
    ENGLISH = 1
    MATH = 2

t = jnp.array([
    [-10, 1, 11],
    [-16, -15, -14],
    [30, jnp.nan, jnp.nan],
])

Mu = jnp.linspace(-25, 25, 100)
Tau = jnp.linspace(1, 30, 100)
Theta = jnp.linspace(-40, 40, 200)
Sigma = jnp.linspace(1, 30, 100) ###NEW
@jax.jit
def half_cauchy(x, scale=1.0):
    return 2 * cauchypdf(x, 0, scale)

@jax.jit
def professor_arrival_likelihood(d, theta, sigma): ###NEW
    ### likelihood of a professor from department d
    ### showing up t_d minutes early/late,
    ### for a given theta and sigma.
    return jnp.exp(jnp.nansum(normlogpdf(t[d], theta, sigma)))
@memo(cache=True)
def department_model[_mu: Mu, _tau: Tau](d):
    department: knows(_mu, _tau)
    department: chooses(theta in Theta, wpp=normpdfjit(theta, _mu, _tau))
    department: chooses(sigma in Sigma, wpp=half_cauchy(sigma, 5)) ###NEW
    return E[ department[professor_arrival_likelihood(d, theta, sigma)] ] ###NEW
@memo(cache=True)
def partial_pooling[_mu: Mu, _tau: Tau](mu_scale=5, tau_scale=5):
    president: knows(_mu, _tau)
    president: thinks[
        population: chooses(mu in Mu, wpp=normpdfjit(mu, 0, mu_scale)),
        population: chooses(tau in Tau, wpp=half_cauchy(tau, tau_scale)),
    ]
    president: observes_event(wpp=department_model[population.mu, population.tau]({Department.GOVERNMENT}))
    president: observes_event(wpp=department_model[population.mu, population.tau]({Department.ENGLISH}))
    president: observes_event(wpp=department_model[population.mu, population.tau]({Department.MATH}))
    return president[Pr[population.mu == _mu, population.tau == _tau]]
@memo(cache=True)
def department_model_theta[_theta: Theta](d, mu_scale, tau_scale):
    obs: thinks[
        department: chooses(mu in Mu, tau in Tau, wpp=partial_pooling[mu, tau](mu_scale, tau_scale)),
        department: chooses(theta in Theta, wpp=normpdfjit(theta, mu, tau)),
        department: chooses(sigma in Sigma, wpp=half_cauchy(sigma, 5)), ###NEW
    ]
    obs: observes_event(wpp=professor_arrival_likelihood(d, department.theta, department.sigma)) ###NEW
    obs: knows(_theta)
    return obs[Pr[department.theta == _theta]]
@memo(cache=True)
def department_model_sigma[_sigma: Sigma](d, mu_scale, tau_scale): ###NEW
    obs: thinks[
        department: chooses(mu in Mu, tau in Tau, wpp=partial_pooling[mu, tau](mu_scale, tau_scale)),
        department: chooses(theta in Theta, wpp=normpdfjit(theta, mu, tau)),
        department: chooses(sigma in Sigma, wpp=half_cauchy(sigma, 5)),
    ]
    obs: observes_event(wpp=professor_arrival_likelihood(d, department.theta, department.sigma))
    obs: knows(_sigma)
    return obs[Pr[department.sigma == _sigma]]
Caching function
# Cache for precomputed results
cache = {}

def compute_distributions(mu_scale, tau_scale, verbose=False):
    """Retrieve from cache or compute the JAX-based distributions."""
    key = (mu_scale, tau_scale)
    if key in cache:
        return cache[key]  # Use cached results
    if verbose:
        print(f"{mu_scale=}, {tau_scale=}")
    # Cache results
    cache[key] = (
        partial_pooling(mu_scale=mu_scale, tau_scale=tau_scale).sum(axis=1),
        partial_pooling(mu_scale=mu_scale, tau_scale=tau_scale).sum(axis=0),
        {d: department_model_sigma(d, mu_scale, tau_scale) for d in Department},
        {d: department_model_theta(d, mu_scale, tau_scale) for d in Department},
    )
    return cache[key]
Plotting function
def plot_model(mu_scale=1, tau_scale=1, figsize=(10, 8)):
    posterior_mu, posterior_tau, sigma_posteriors, theta_posteriors = compute_distributions(mu_scale, tau_scale)

    fig, axs = plt.subplots(4, 1, figsize=figsize)

    ax = axs[0]
    ax.axvline(0, color="black", linestyle="-")
    ax.plot(Mu, posterior_mu, label=r"$P(\mu \mid t)$")
    mu_expectation = jnp.dot(Mu, posterior_mu)
    ax.axvline(
        mu_expectation,
        color='red',
        linestyle='--',
        label=r"$\operatorname{E}" + rf"[\mu \mid t]={mu_expectation:0.2f}$")
    _ = ax.set_title(r"Posterior of $\mu$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    ax = axs[1]
    ax.axvline(0, color="black", linestyle="-")
    ax.plot(Tau, posterior_tau, label=r"$P(\tau \mid t)$")
    tau_expectation = jnp.dot(Tau, posterior_tau)
    ax.axvline(
        tau_expectation,
        color='red',
        linestyle='--',
        label=r"$\operatorname{E}" + rf"[\tau \mid t]={tau_expectation:0.2f}$")
    _ = ax.set_title(r"Posterior of $\tau$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    ax = axs[2]
    ax.axvline(0, color="black", linestyle="-")
    for d in Department:
        department_name = d.name
        department_abbrev = department_name[0]
        sigma_posterior = sigma_posteriors[d]
        sigma_expectation = jnp.dot(Sigma, sigma_posterior)

        ax.plot(
            Sigma,
            sigma_posterior,
            label=(
                rf"$P(\sigma_{department_abbrev} \mid t),~ "
                + r"\operatorname{E}"
                + rf"[\sigma_{department_abbrev} \mid t]={sigma_expectation:0.2f}$"))
    _ = ax.set_xlim(-1, 20)
    _ = ax.set_title(r"Posterior of $\sigma_d$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    ax = axs[3]
    ax.axvline(0, color="black", linestyle="-")
    for d in Department:
        department_name = d.name
        department_abbrev = department_name[0]
        theta_posterior = theta_posteriors[d]
        theta_expectation = jnp.dot(Theta, theta_posterior)

        ax.plot(
            Theta,
            theta_posterior,
            label=(
                rf"$P(\theta_{department_abbrev} \mid t),~ "
                + r"\operatorname{E}"
                + rf"[\theta_{department_abbrev} \mid t]={theta_expectation:0.2f}$"))
    _ = ax.set_xlim(-20, 40)
    _ = ax.set_title(r"Posterior of $\theta_d$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    _ = plt.suptitle(f"mu_scale = {mu_scale}, tau_scale = {tau_scale}", y=1)

    plt.tight_layout()
    plt.show()
param_list = [
    dict(mu_scale=1, tau_scale=1),
    # dict(mu_scale=5, tau_scale=5),
    dict(mu_scale=10, tau_scale=10),
    # dict(mu_scale=20, tau_scale=20),
    dict(mu_scale=10, tau_scale=1),
    # dict(mu_scale=20, tau_scale=1),
    dict(mu_scale=1, tau_scale=10),
    # dict(mu_scale=1, tau_scale=20),
]

plt.close("all")
for i_params_, params_ in enumerate(param_list):
    plot_model(**params_, figsize=(10, 8))
- McElreath (2016, Chapters 1, 13, 14)
- Gelman (2014, Chapter 5)
- Kruschke (2015, Chapter 9)
%reset -f
import sys
import platform
import importlib.metadata
print("Python:", sys.version)
print("Platform:", platform.system(), platform.release())
print("Processor:", platform.processor())
print("Machine:", platform.machine())
print("\nPackages:")
for name, version in sorted(
    ((dist.metadata["Name"], dist.version) for dist in importlib.metadata.distributions()),
    key=lambda x: x[0].lower()  # Sort case-insensitively
):
    print(f"{name}=={version}")
Python: 3.13.2 (main, Feb 5 2025, 18:58:04) [Clang 19.1.6 ]
Platform: Darwin 23.6.0
Processor: arm
Machine: arm64
Packages:
annotated-types==0.7.0
anyio==4.8.0
appnope==0.1.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==3.0.0
async-lru==2.0.4
attrs==25.1.0
babel==2.17.0
beautifulsoup4==4.13.3
bleach==6.2.0
certifi==2025.1.31
cffi==1.17.1
cfgv==3.4.0
charset-normalizer==3.4.1
click==8.1.8
comm==0.2.2
contourpy==1.3.1
cycler==0.12.1
debugpy==1.8.13
decorator==5.2.1
defusedxml==0.7.1
distlib==0.3.9
executing==2.2.0
fastjsonschema==2.21.1
filelock==3.17.0
fonttools==4.56.0
fqdn==1.5.1
h11==0.14.0
httpcore==1.0.7
httpx==0.28.1
identify==2.6.8
idna==3.10
importlib_metadata==8.6.1
ipykernel==6.29.5
ipython==9.0.1
ipython_pygments_lexers==1.1.1
ipywidgets==8.1.5
isoduration==20.11.0
jax==0.5.2
jaxlib==0.5.1
jedi==0.19.2
Jinja2==3.1.6
joblib==1.4.2
json5==0.10.0
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
jupyter-cache==1.0.1
jupyter-events==0.12.0
jupyter-lsp==2.2.5
jupyter_client==8.6.3
jupyter_core==5.7.2
jupyter_server==2.15.0
jupyter_server_terminals==0.5.3
jupyterlab==4.3.5
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.3
jupyterlab_widgets==3.0.13
kiwisolver==1.4.8
MarkupSafe==3.0.2
matplotlib==3.10.1
matplotlib-inline==0.1.7
memo-lang==1.1.0
mistune==3.1.2
ml_dtypes==0.5.1
nbclient==0.10.2
nbconvert==7.16.6
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.4.2
nodeenv==1.9.1
notebook_shim==0.2.4
numpy==2.2.3
opt_einsum==3.4.0
optype==0.9.1
overrides==7.7.0
packaging==24.2
pandas==2.2.3
pandas-stubs==2.2.3.241126
pandocfilters==1.5.1
parso==0.8.4
pexpect==4.9.0
pillow==11.1.0
platformdirs==4.3.6
plotly==5.24.1
pre_commit==4.1.0
prometheus_client==0.21.1
prompt_toolkit==3.0.50
psutil==7.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
pycparser==2.22
pydantic==2.10.6
pydantic_core==2.27.2
Pygments==2.19.1
pygraphviz==1.14
pyparsing==3.2.1
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-json-logger==3.3.0
pytz==2025.1
PyYAML==6.0.2
pyzmq==26.2.1
referencing==0.36.2
requests==2.32.3
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.23.1
ruff==0.9.10
scikit-learn==1.6.1
scipy==1.15.2
scipy-stubs==1.15.2.0
seaborn==0.13.2
Send2Trash==1.8.3
setuptools==75.8.2
six==1.17.0
sniffio==1.3.1
soupsieve==2.6
SQLAlchemy==2.0.38
stack-data==0.6.3
tabulate==0.9.0
tenacity==9.0.0
terminado==0.18.1
threadpoolctl==3.5.0
tinycss2==1.4.0
toml==0.10.2
tornado==6.4.2
tqdm==4.67.1
traitlets==5.14.3
types-python-dateutil==2.9.0.20241206
types-pytz==2025.1.0.20250204
typing_extensions==4.12.2
tzdata==2025.1
uri-template==1.3.0
urllib3==2.3.0
virtualenv==20.29.3
wcwidth==0.2.13
webcolors==24.11.1
webencodings==0.5.1
websocket-client==1.8.0
widgetsnbextension==4.0.13
xarray==2025.1.2
zipp==3.21.0