%reset -f
import jax
import jax.numpy as jnp
from memo import memo
from enum import IntEnum
from jax.scipy.stats.norm import pdf as normpdf
from jax.scipy.stats.cauchy import pdf as cauchypdf
from matplotlib import pyplot as plt

normpdfjit = jax.jit(normpdf)

class Department(IntEnum):
    GOVERNMENT = 0
    ENGLISH = 1
    MATH = 2

t = jnp.array([1, -15, 30])
sigma = 15

Theta = jnp.linspace(-40, 40, 200)

@jax.jit
def professor_arrival_likelihood(d, theta):
    ### likelihood of a professor from department d
    ### showing up t_d minutes early/late,
    ### under the hypothesis given by theta and sigma.
    return normpdf(t[d], loc=theta, scale=sigma)

@memo
def complete_pooling[
    _theta: Theta,
](mu=0, tau=1):
    president: knows(_theta)
    president: thinks[
        department: chooses(theta in Theta, wpp=normpdfjit(theta, mu, tau))
    ]
    president: observes_event(
        wpp=professor_arrival_likelihood({Department.GOVERNMENT}, department.theta))
    president: observes_event(
        wpp=professor_arrival_likelihood({Department.ENGLISH}, department.theta))
    president: observes_event(
        wpp=professor_arrival_likelihood({Department.MATH}, department.theta))
    return president[Pr[department.theta == _theta]]

mu_ = 0
tau_ = 20
res = complete_pooling(mu=mu_, tau=tau_)

### check the size and sum of the output
# res.shape
# res.sum()

fig, ax = plt.subplots()

ax.axvline(0, color="black", linestyle="-")
theta_posterior = res
theta_expectation = jnp.dot(Theta, theta_posterior)
ax.plot(Theta, theta_posterior, label=r"$P(\theta \mid t)$")
ax.axvline(
    theta_expectation,
    color='red',
    linestyle='--',
    label=(
        r"$\operatorname{E}"
        + rf"[\theta \mid t]={theta_expectation:6.2f}$")
)
_ = ax.set_title(r"Posterior of $\theta$")
_ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

_ = plt.suptitle(fr"$\dot\tau$ = {tau_}", y=1)

plt.tight_layout()
plt.show()
Multilevel models
NB the Jupyter notebook includes interactive widgets that are very useful for building an understanding of these models.
Departmental punctuality
Imagine that your university hires a new president. On her first week, she schedules a meeting with professors from three departments, { D = \{ \text{G}, ~ \text{E}, ~ \text{M} \} }. The professor from the Department of Government ({ d{=}\text{G} }) shows up 1 minute late, the professor from the Department of English ({ d{=}\text{E} }) shows up 15 minutes before the scheduled meeting time, and the professor from the Department of Math ({ d{=}\text{M} }) shows up 30 minutes late. From this single meeting the new president updates her beliefs about when professors will show up to scheduled meetings.
But how should the president update her beliefs? It could be that departments have different relationships with punctuality and it would be useful to expect professors from the Math department to show up later than professors from the English department. It could also be the case that there is not meaningful between-department variance and she should just learn about the greater population of professors at the university.
A model that infers information about the population without differentiating subpopulations is doing “complete pooling”, meaning that all of the variance in the data is treated as being generated by the same distribution. I.e. the data is pooled together and treated as iid observations drawn from a single distribution.
If the president thinks of each department independently, what she learns about the punctuality of professors from one department will not change her beliefs about professors from other departments. In this case, there would be “no pooling” of information across departments. The president’s observations of the three professors would have no bearing on her belief about what to expect of a professor from a fourth department (e.g. Psychology).
“Partial pooling” models represent uncertainty at multiple levels of abstraction. Based on her observations from this meeting, the president will update her beliefs about these departments as well as the greater population. In other words, what she observes about the Math department might strongly update her beliefs about that department, moderately update her beliefs about the broader population, and that in turn could weakly update her beliefs about the English department.
Complete pooling
Let’s build a complete pooling model of the president’s belief about the punctuality of professors. The data observed by the president are { t = \langle t_G, t_E, t_M \rangle = \langle 1, -15, 30 \rangle }. We’ll model the president’s belief about the greater population of professors as a normal distribution centered at \theta with standard deviation \dot\sigma. Prior to meeting with anyone, the president thinks that professors will, on average, show up on time, but that some will show up early and some late. We can encode this prior belief as a normal distribution centered at zero, with some standard deviation, let’s say 20. I.e. the prior for \theta is:
\theta ~\sim~ \mathcal{N}\left( \dot{\mu}{=}0, ~ \dot{\tau}{=}20 \right)
If she observes that professors tend to be late, then her posterior estimate of \theta will have an expected value greater than zero (i.e. { \operatorname{E}[\theta \mid t] > 0 }), and if she ends up expecting professors to be early, then { \operatorname{E}[\theta \mid t] < 0 }.
In this tutorial I’ll use a dot above to indicate that a variable is fixed, not inferred. (N.B. the dot is not standard notation, just what I’m using here.) In the graphical schematic, fixed parameters are depicted as bare symbols rather than nodes. Note that “fixed” means the values are not updated during inference (but you can, of course, manually tune these values or learn suitable values by minimizing some loss function in conjunction with a model fitting algorithm).
Since the model infers the distribution of \theta, the graphical schematic depicts it as an open node (i.e. a latent random variable). The data in this model are t.
We specify the likelihood as,
t_d ~\sim~ \mathcal{N}(\theta, ~ \dot\sigma{=}15)
I.e., for whatever the true value of \theta is for the population of professors, the president thinks that when professors actually show up will follow a normal distribution, centered on the true value of \theta, with standard deviation 15.
\begin{align*} \theta ~\sim&~ \mathcal{N}\left( \dot{\mu}{=}0, ~ \dot{\tau}{=}20 \right) \\ t_d ~\sim&~ \mathcal{N}(\theta, ~ \dot\sigma{=}15) \\ d ~\in&~ D,~~~ \text{where}~~~ D = \{ \text{G}, ~ \text{E}, ~ \text{M} \} \end{align*}
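Putting these pieces together, the quantity the complete_pooling program computes is the posterior over \theta given all three arrival times (normalized over the grid Theta):

P(\theta \mid t) ~\propto~ \mathcal{N}\left(\theta ;~ \dot{\mu}, ~ \dot{\tau}\right) \prod_{d \in D} \mathcal{N}\left(t_d ;~ \theta, ~ \dot\sigma\right)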
Try changing the data to t = jnp.array([-20, -10, 30]). Notice how changing the data for one group updates the expectation of the population.
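As a minimal sketch of that experiment (assuming you edit t in the setup cell above and re-run that cell, since the jitted likelihood captures t when it is first traced), you can then recompute the posterior and compare its expectation:

### Sketch of the suggested experiment. In the setup cell, change the data to
###     t = jnp.array([-20, -10, 30])
### and re-run that cell (re-running avoids stale jit/memo caches), then:
res_new = complete_pooling(mu=0, tau=20)
print("E[theta | t] =", float(jnp.dot(Theta, res_new)))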
No pooling
In contrast to complete pooling, which assumes all observations are generated by a shared population-level distribution parameterized by \theta, the no pooling model treats each department as having its own independent parameter \theta_d. The key change is that instead of having a single \theta drawn from the prior distribution, we now have three independent parameters \theta_G, \theta_E, and \theta_M. The practical meaning of this change is significant:
In complete pooling, an observation of the Math professor being late would shift our expectations for all departments. In no pooling, learning that the Math professor is late only updates our beliefs about the Math department. Each department’s parameter is estimated using only data from that department.
The key change in our model specification is that we have replaced the single \theta with three parameters, \theta_d.
\begin{align*} \theta_d ~\sim&~ \mathcal{N}\left( \dot{\mu}{=}0, ~ \dot{\tau}{=}20 \right) \\ t_d ~\sim&~ \mathcal{N}(\theta_d, ~ \dot\sigma{=}15) \\ d ~\in&~ D,~~~ \text{where}~~~ D = \{ \text{G}, ~ \text{E}, ~ \text{M} \} \end{align*}
The plate (rectangular box) with label d indicates that everything inside is repeated for each department d \in D. Variables inside the plate have a unique value for each department. Variables outside the plate are shared across all departments.
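Because each department has its own \theta_d and only its own observation enters the likelihood, the posterior factorizes across departments:

P(\theta_d \mid t) ~=~ P(\theta_d \mid t_d) ~\propto~ \mathcal{N}\left(\theta_d ;~ \dot{\mu}, ~ \dot{\tau}\right) \, \mathcal{N}\left(t_d ;~ \theta_d, ~ \dot\sigma\right)

so, for instance, observing t_M changes the president’s beliefs about \theta_M only.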
To change the complete_pooling model to the no_pooling model we make a simple change, adding the conditional statement president: observes [department.d] is _d and the query variable _d in the definition, [_d: Department, ...].
%reset -f
import jax
import jax.numpy as jnp
from memo import memo
from enum import IntEnum
from jax.scipy.stats.norm import pdf as normpdf
from jax.scipy.stats.cauchy import pdf as cauchypdf
from matplotlib import pyplot as plt

normpdfjit = jax.jit(normpdf)

class Department(IntEnum):
    GOVERNMENT = 0
    ENGLISH = 1
    MATH = 2

t = jnp.array([1, -15, 30])
sigma = 15

Theta = jnp.linspace(-40, 40, 200)

@jax.jit
def professor_arrival_likelihood(d, theta):
    ### likelihood of a professor from department d
    ### showing up t_d minutes early/late,
    ### under the hypothesis given by theta and sigma.
    return normpdf(t[d], loc=theta, scale=sigma)

@memo
def no_pooling[
    _d: Department,
    _theta: Theta,
](mu=0, tau=1):
    president: knows(_theta)
    president: thinks[
        department: chooses(d in Department, wpp=1),
        department: chooses(theta in Theta, wpp=normpdfjit(theta, mu, tau))
    ]
    president: observes [department.d] is _d  ### new conditional statement
    president: observes_event(
        wpp=professor_arrival_likelihood(department.d, department.theta))
    return president[Pr[department.theta == _theta]]

mu_ = 0
tau_ = 20
res = no_pooling(mu=mu_, tau=tau_)

### check the size and sum of the output
# res.shape
# res.sum()
# res[0].sum()
# res[1].sum()
# res[2].sum()

fig, ax = plt.subplots()

ax.axvline(0, color="black", linestyle="-")
for d in Department:
    department_name = d.name
    department_abbrev = department_name[0]
    theta_posterior = res[d]
    theta_expectation = jnp.dot(Theta, theta_posterior)
    ax.plot(
        Theta,
        theta_posterior,
        label=(
            rf"$p(\theta_{department_abbrev} \mid t),~ "
            + r"\operatorname{E}"
            + rf"[\theta_{department_abbrev} \mid t]={theta_expectation:6.2f}$")
    )
_ = ax.set_title(r"Posterior of $\theta_d$")
_ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

_ = plt.suptitle(fr"$\dot\tau$ = {tau_}", y=1)

plt.tight_layout()
plt.show()
Try changing the data to t = jnp.array([-20, -15, 30]). Notice how changing the data for one group does not affect the expectation of the posterior of the other groups.
Previously, the complete_pooling model returned an array of shape (200,), with 200 being the size of the sample space of Theta. With the addition of the conditional statement, what array shape does no_pooling return? Make sure you understand the shape change, what each dimension reflects, and why the values sum in the ways that they do.
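A quick sketch of how to check this empirically (the expected shape follows from the two query variables, _d: Department and _theta: Theta):

### Sketch: inspect the output of no_pooling.
### The result should be indexed by (department, theta), i.e. shape (3, 200);
### each row is a posterior over theta conditioned on observing that department,
### so each row should sum to ~1 and the whole array to ~3.
res = no_pooling(mu=0, tau=20)
print(res.shape)
print(res[Department.MATH].sum())
print(res.sum())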
The key theoretical distinction is that in no pooling, we’re asserting that the timing behavior of professors in different departments is independent of the behavior of professors from other departments (i.e. when professors from one department show up for meetings carries no information about when professors from a different department will show up). This is an unrealistically strong assumption.
The no_pooling and complete_pooling models highlight an important practical tradeoff: no pooling can better capture differences between departments but may suffer from high variance in its estimates due to small sample sizes within each department. Complete pooling has lower variance but may miss important systematic differences between departments.
Partial pooling
In partial pooling, every observation updates beliefs across multiple levels of abstraction. The critical insight is that while departments may differ in their typical timing behavior, the departments are not statistically independent. They are influenced by common causes. To model the president’s reasoning about these multiple levels of uncertainty, we modify the no_pooling model by adding priors over \mu and \tau.
Rather than being fixed values, \mu and \tau become latent parameters to be jointly inferred alongside \theta_d.
The higher-level latent parameters \mu and \tau characterize a “population-level” distribution from which department-specific \theta_d values are generated.
The key change to the model is that rather than { \theta_d ~\sim~ \mathcal{N}\left( \dot{\mu}{=}0, ~ \dot{\tau}{=}20 \right) }, we have { \theta_d ~\sim~ \mathcal{N}\left(\mu, ~ \tau \right) } and specify hyperpriors on these new latent parameters: { \mu ~\sim~ \mathcal{N}(0, ~ \dot\sigma_{\mu}) } and { \tau ~\sim~ \text{HalfCauchy}(\dot\sigma_{\tau}) }.
Note that there are new fixed parameters: mu_scale ({ \dot\sigma_\mu }) and tau_scale ({ \dot\sigma_\tau }), but I am omitting them from the graphical schematic for the sake of visual simplicity.
\begin{align*} \mu ~\sim&~ \mathcal{N}(0, ~ \dot\sigma_{\mu}) \\ \tau ~\sim&~ \text{HalfCauchy}(\dot\sigma_{\tau}) \\ \theta_d ~\sim&~ \mathcal{N}\left(\mu, ~ \tau \right) \\ t_d ~\sim&~ \mathcal{N}(\theta_d, ~ \dot\sigma{=}15) \\ d ~\in&~ D,~~~ \text{where}~~~ D = \{ \text{G}, ~ \text{E}, ~ \text{M} \} \end{align*}
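The department_model and partial_pooling programs below implement this on a grid: department_model marginalizes \theta_d to compute the evidence that a department’s arrival time provides about a candidate (\mu, \tau), and partial_pooling combines those evidence terms with the hyperpriors,

P(\mu, \tau \mid t) ~\propto~ \mathcal{N}(\mu ;~ 0, ~ \dot\sigma_{\mu}) \; \text{HalfCauchy}(\tau ;~ \dot\sigma_{\tau}) \prod_{d \in D} \int \mathcal{N}(\theta_d ;~ \mu, ~ \tau) \, \mathcal{N}(t_d ;~ \theta_d, ~ \dot\sigma) \, d\theta_d

where the integral over \theta_d is approximated by an expectation over the Theta grid.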
%reset -f
import jax
import jax.numpy as jnp
from memo import memo
from enum import IntEnum
from jax.scipy.stats.norm import pdf as normpdf
from jax.scipy.stats.cauchy import pdf as cauchypdf
from matplotlib import pyplot as plt

normpdfjit = jax.jit(normpdf)

class Department(IntEnum):
    GOVERNMENT = 0
    ENGLISH = 1
    MATH = 2

t = jnp.array([1, -15, 30])
sigma = 15

Mu = jnp.linspace(-25, 25, 100)   ### sample space for new hyperprior
Tau = jnp.linspace(1, 30, 100)    ### sample space for new hyperprior
Theta = jnp.linspace(-40, 40, 200)

### PDF for new hyperprior
@jax.jit
def half_cauchy(x, scale=1.0):
    return 2 * cauchypdf(x, 0, scale)

@jax.jit
def professor_arrival_likelihood(d, theta):
    ### likelihood of a professor from department d
    ### showing up t_d minutes early/late,
    ### under the hypothesis given by theta and sigma.
    return normpdf(t[d], loc=theta, scale=sigma)

@memo
def department_model[_mu: Mu, _tau: Tau](d):
    department: knows(_mu, _tau)
    department: chooses(theta in Theta, wpp=normpdfjit(theta, _mu, _tau))
    return E[ department[professor_arrival_likelihood(d, theta)] ]

@memo
def partial_pooling[_mu: Mu, _tau: Tau](mu_scale=5, tau_scale=5):
    president: knows(_mu, _tau)
    president: thinks[
        population: chooses(mu in Mu, wpp=normpdfjit(mu, 0, mu_scale)),
        population: chooses(tau in Tau, wpp=half_cauchy(tau, tau_scale)),
    ]
    president: observes_event(
        wpp=department_model[population.mu, population.tau]({Department.GOVERNMENT}))
    president: observes_event(
        wpp=department_model[population.mu, population.tau]({Department.ENGLISH}))
    president: observes_event(
        wpp=department_model[population.mu, population.tau]({Department.MATH}))
    return president[Pr[population.mu == _mu, population.tau == _tau]]
Let’s think about what each variable represents.
The likelihood model professor_arrival_likelihood
gives the probability that a professor from department d shows up t_d minutes early/late, under a given \theta_d and \sigma:
t_d ~\sim~ \mathcal{N}(\theta_d, ~ \dot\sigma{=}15)
We can write the likelihood as { P(t_d \mid \theta_d, ~ \mu, ~\tau ~;~ \sigma) }. The marginal posterior { P(\theta_d \mid t) } will thus represent the president’s belief about the true value of \theta_d for a given department d.
@memo
def department_model_theta[_theta: Theta](d, mu_scale, tau_scale):
    obs: thinks[
        department: chooses(mu in Mu, tau in Tau, wpp=partial_pooling[mu, tau](mu_scale, tau_scale)),
        department: chooses(theta in Theta, wpp=normpdfjit(theta, mu, tau))
    ]
    obs: observes_event(wpp=professor_arrival_likelihood(d, department.theta))
    obs: knows(_theta)
    return obs[Pr[department.theta == _theta]]
Plotting function
def plot_model(mu_scale=1, tau_scale=1, figsize=(10, 8), verbose=False):
    posterior = partial_pooling(mu_scale=mu_scale, tau_scale=tau_scale)

    # Marginal over Tau (sum over Mu)
    posterior_tau = posterior.sum(axis=0)
    # Marginal over Mu (sum over Tau)
    posterior_mu = posterior.sum(axis=1)

    fig, axs = plt.subplots(3, 1, figsize=figsize)

    ax = axs[0]
    ax.axvline(0, color="black", linestyle="-")
    ax.plot(Mu, posterior_mu, label=r"$P(\mu \mid t)$")
    mu_expectation = jnp.dot(Mu, posterior_mu)
    ax.axvline(
        mu_expectation,
        color='red',
        linestyle='--',
        label=r"$\operatorname{E}" + rf"[\mu \mid t]={mu_expectation:6.2f}$")
    _ = ax.set_title(r"Posterior of $\mu$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    ax = axs[1]
    ax.axvline(0, color="black", linestyle="-")
    ax.plot(Tau, posterior_tau, label=r"$P(\tau \mid t)$")
    tau_expectation = jnp.dot(Tau, posterior_tau)
    ax.axvline(
        tau_expectation,
        color='red',
        linestyle='--',
        label=r"$\operatorname{E}" + rf"[\tau \mid t]={tau_expectation:6.2f}$")
    _ = ax.set_title(r"Posterior of $\tau$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    ax = axs[2]
    ax.axvline(0, color="black", linestyle="-")
    for d in Department:
        department_name = d.name
        department_abbrev = department_name[0]
        theta_posterior = department_model_theta(d, mu_scale, tau_scale)
        theta_expectation = jnp.dot(Theta, theta_posterior)
        ax.plot(
            Theta,
            theta_posterior,
            label=(
                rf"$P(\theta_{department_abbrev} \mid t),~ "
                + r"\operatorname{E}"
                + rf"[\theta_{department_abbrev} \mid t]={theta_expectation:6.2f}$"))
    _ = ax.set_title(r"Posterior of $\theta_d$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    _ = plt.suptitle(f"mu_scale = {mu_scale}, tau_scale = {tau_scale}", y=1)

    plt.tight_layout()
    plt.show()

    if verbose:
        for d in Department:
            department_name = d.name
            department_abbrev = department_name[0]
            theta_posterior = department_model_theta(d, mu_scale, tau_scale)
            posterior_mean = jnp.average(Theta, weights=theta_posterior)
            posterior_second_moment = jnp.average(Theta**2, weights=theta_posterior)
            posterior_variance = posterior_second_moment - posterior_mean**2
            print(f"{mu_scale=}, {tau_scale=} : E[θ{department_abbrev} | t] = {posterior_mean:7.3f} , Var[θ{department_abbrev} | t] = {posterior_variance:7.3f}")
When mu_scale is small (e.g. { \dot\sigma_{\mu}{=}1 }) the location of the distribution over \theta_d will be centered near zero. When mu_scale is large (e.g. { \dot\sigma_{\mu}{=}10 }), the center of { \mathcal{N}(\mu, ~ \tau) } can move to new locations with greater freedom. This results in a more flexible population-level mean, allowing the model to adapt to data that suggests professors generally arrive early or late across all departments. With a large mu_scale, the model places less prior constraint on where the overall population mean should be, letting the data have more influence on determining this parameter.
param_list = [
    dict(mu_scale=1, tau_scale=3),
    dict(mu_scale=3, tau_scale=3),
    dict(mu_scale=10, tau_scale=3),
]
for i_params_, params_ in enumerate(param_list):
    plot_model(**params_, figsize=(6, 5), verbose=True)
mu_scale=1, tau_scale=3 : E[θG | t] = 0.187 , Var[θG | t] = 29.778
mu_scale=1, tau_scale=3 : E[θE | t] = -2.155 , Var[θE | t] = 40.307
mu_scale=1, tau_scale=3 : E[θM | t] = 6.092 , Var[θM | t] = 82.286
mu_scale=3, tau_scale=3 : E[θG | t] = 0.595 , Var[θG | t] = 35.435
mu_scale=3, tau_scale=3 : E[θE | t] = -2.168 , Var[θE | t] = 46.285
mu_scale=3, tau_scale=3 : E[θM | t] = 6.891 , Var[θM | t] = 79.738
mu_scale=10, tau_scale=3 : E[θG | t] = 2.370 , Var[θG | t] = 61.138
mu_scale=10, tau_scale=3 : E[θE | t] = -2.297 , Var[θE | t] = 73.436
mu_scale=10, tau_scale=3 : E[θM | t] = 10.948 , Var[θM | t] = 80.582
When tau_scale is small (e.g. { \dot\sigma_{\tau}{=}1 }), the distribution { \theta_d ~\sim~ \mathcal{N}(\mu, ~ \tau) } will be narrow, which will result in strong shrinkage of department-specific means toward the overall population mean. This effectively reduces the differences between departments, making the model behave more like the complete pooling model. The practical interpretation is that the model assumes departments are relatively homogeneous in their timing behavior.
When tau_scale is large (e.g. { \dot\sigma_{\tau}{=}20 }), the { \theta_d } can diverge more freely from the population mean, resulting in weaker shrinkage effects. Department-specific estimates will be closer to their empirical means and less influenced by data from other departments. This allows the model to capture more heterogeneity between departments, making it behave more like the no pooling model while still maintaining some information sharing.
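For intuition about this shrinkage, the standard normal–normal result is useful: with \mu, \tau, and \sigma held fixed and a single observation t_d, the posterior mean of \theta_d is a precision-weighted average of the data and the population mean,

\operatorname{E}[\theta_d \mid t_d, \mu, \tau] ~=~ \frac{t_d / \sigma^2 + \mu / \tau^2}{1 / \sigma^2 + 1 / \tau^2}

which sits near \mu when \tau is small (strong shrinkage) and near t_d when \tau is large (weak shrinkage).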
param_list = [
    dict(mu_scale=1, tau_scale=1),
    dict(mu_scale=1, tau_scale=3),
    dict(mu_scale=1, tau_scale=10),
    dict(mu_scale=1, tau_scale=20),
]
for i_params_, params_ in enumerate(param_list):
    plot_model(**params_, figsize=(6, 5), verbose=True)
mu_scale=1, tau_scale=1 : E[θG | t] = 0.139 , Var[θG | t] = 17.512
mu_scale=1, tau_scale=1 : E[θE | t] = -1.247 , Var[θE | t] = 24.192
mu_scale=1, tau_scale=1 : E[θM | t] = 3.798 , Var[θM | t] = 55.836
mu_scale=1, tau_scale=3 : E[θG | t] = 0.187 , Var[θG | t] = 29.778
mu_scale=1, tau_scale=3 : E[θE | t] = -2.155 , Var[θE | t] = 40.307
mu_scale=1, tau_scale=3 : E[θM | t] = 6.092 , Var[θM | t] = 82.286
mu_scale=1, tau_scale=10 : E[θG | t] = 0.298 , Var[θG | t] = 57.456
mu_scale=1, tau_scale=10 : E[θE | t] = -4.140 , Var[θE | t] = 73.415
mu_scale=1, tau_scale=10 : E[θM | t] = 10.436 , Var[θM | t] = 116.686
mu_scale=1, tau_scale=20 : E[θG | t] = 0.366 , Var[θG | t] = 74.551
mu_scale=1, tau_scale=20 : E[θE | t] = -5.324 , Var[θE | t] = 91.824
mu_scale=1, tau_scale=20 : E[θM | t] = 12.698 , Var[θM | t] = 127.697
In this way, a large { \dot\sigma_{\mu} } and small { \dot\sigma_{\tau} } would create a model that assumes departments are relatively homogeneous (due to small \tau) but is very flexible about where the overall population mean might be (due to large \mu scale). This combination would result in strong shrinkage of department-specific estimates toward a population mean that is itself quite adaptable to the data. Conversely, a small { \dot\sigma_{\mu} } and large { \dot\sigma_{\tau} } would create a model with a strong prior belief that the population mean is near zero, while allowing substantial variation between departments.
param_list = [
    dict(mu_scale=1, tau_scale=1),
    # dict(mu_scale=5, tau_scale=5),
    # dict(mu_scale=10, tau_scale=10),
    dict(mu_scale=20, tau_scale=20),
    # dict(mu_scale=10, tau_scale=1),
    dict(mu_scale=20, tau_scale=3),
    # dict(mu_scale=1, tau_scale=10),
    dict(mu_scale=3, tau_scale=20),
]
for i_params_, params_ in enumerate(param_list):
    plot_model(**params_, figsize=(6, 5), verbose=True)
mu_scale=1, tau_scale=1 : E[θG | t] = 0.139 , Var[θG | t] = 17.512
mu_scale=1, tau_scale=1 : E[θE | t] = -1.247 , Var[θE | t] = 24.192
mu_scale=1, tau_scale=1 : E[θM | t] = 3.798 , Var[θM | t] = 55.836
mu_scale=20, tau_scale=20 : E[θG | t] = 2.517 , Var[θG | t] = 112.461
mu_scale=20, tau_scale=20 : E[θE | t] = -6.000 , Var[θE | t] = 129.325
mu_scale=20, tau_scale=20 : E[θM | t] = 16.936 , Var[θM | t] = 112.901
mu_scale=20, tau_scale=3 : E[θG | t] = 3.228 , Var[θG | t] = 74.198
mu_scale=20, tau_scale=3 : E[θE | t] = -2.422 , Var[θE | t] = 87.779
mu_scale=20, tau_scale=3 : E[θM | t] = 13.039 , Var[θM | t] = 83.766
mu_scale=3, tau_scale=20 : E[θG | t] = 0.632 , Var[θG | t] = 78.673
mu_scale=3, tau_scale=20 : E[θE | t] = -5.378 , Var[θE | t] = 96.102
mu_scale=3, tau_scale=20 : E[θM | t] = 13.063 , Var[θM | t] = 124.248
Multiple observations (with missing data)
What if the president instead had a large meeting with 7 professors and the departments were not equally represented? How would she update her beliefs given unequal observations? Let’s say that she invited 3 professors from the Department of Government, 3 from the Department of English, and 1 from the Department of Math. We can explore how modulating the higher-level hyperpriors changes her posterior inference about each department.
\begin{align*} \mu ~\sim&~ \mathcal{N}(0, ~ \dot\sigma_{\mu}) \\ \tau ~\sim&~ \text{HalfCauchy}(\dot\sigma_{\tau}) \\ \theta_d ~\sim&~ \mathcal{N}\left(\mu, ~ \tau \right) \\ t_{d,i} ~\sim&~ \mathcal{N}(\theta_d, ~ \dot\sigma{=}15) \\ d ~\in&~ D,~~~ \text{where}~~~ D = \{ \text{G}, ~ \text{E}, ~ \text{M} \} \end{align*}
In this diagram we add a new plate, i, which is nested in plate d, to index the observations within each department.
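One implementation note about the code below: because departments now have different numbers of observations, t is stored as a NaN-padded array, and the likelihood sums log-densities with jnp.nansum so that missing entries contribute nothing. A minimal standalone sketch of that idea (hypothetical values):

### Minimal sketch of the NaN-padding idea used in the likelihood below.
### nansum treats the NaN log-densities of padded (missing) observations as 0,
### so only the real observations contribute to a department's log-likelihood.
import jax.numpy as jnp
from jax.scipy.stats.norm import logpdf as normlogpdf

row = jnp.array([30.0, jnp.nan, jnp.nan])        # one observation, two missing
loglik = jnp.nansum(normlogpdf(row, 0.0, 15.0))  # equals normlogpdf(30.0, 0.0, 15.0)
print(loglik)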
%reset -f
import jax
import jax.numpy as jnp
from memo import memo
from enum import IntEnum
from jax.scipy.stats.norm import pdf as normpdf
from jax.scipy.stats.norm import logpdf as normlogpdf
from jax.scipy.stats.cauchy import pdf as cauchypdf
from matplotlib import pyplot as plt

normpdfjit = jax.jit(normpdf)

class Department(IntEnum):
    GOVERNMENT = 0
    ENGLISH = 1
    MATH = 2

###NEW
t = jnp.array([
    [-10, 1, 11],
    [-16, -15, -14],
    [30, jnp.nan, jnp.nan],
])

sigma = 15

Mu = jnp.linspace(-25, 25, 100)
Tau = jnp.linspace(1, 30, 100)
Theta = jnp.linspace(-40, 40, 200)

@jax.jit
def half_cauchy(x, scale=1.0):
    return 2 * cauchypdf(x, 0, scale)

@jax.jit
def professor_arrival_likelihood(d, theta):
    ### likelihood of a professor from department d
    ### showing up t_d minutes early/late,
    ### under the hypothesis given by theta and sigma.
    return jnp.exp(jnp.nansum(normlogpdf(t[d], theta, sigma)))  ###NEW

@memo
def department_model[_mu: Mu, _tau: Tau](d):
    department: knows(_mu, _tau)
    department: chooses(theta in Theta, wpp=normpdfjit(theta, _mu, _tau))
    return E[ department[professor_arrival_likelihood(d, theta)] ]

@memo
def partial_pooling[_mu: Mu, _tau: Tau](mu_scale=5, tau_scale=5):
    president: knows(_mu, _tau)
    president: thinks[
        population: chooses(mu in Mu, wpp=normpdfjit(mu, 0, mu_scale)),
        population: chooses(tau in Tau, wpp=half_cauchy(tau, tau_scale)),
    ]
    president: observes_event(wpp=department_model[population.mu, population.tau]({Department.GOVERNMENT}))
    president: observes_event(wpp=department_model[population.mu, population.tau]({Department.ENGLISH}))
    president: observes_event(wpp=department_model[population.mu, population.tau]({Department.MATH}))
    return president[Pr[population.mu == _mu, population.tau == _tau]]

@memo
def department_model_theta[_theta: Theta](d, mu_scale, tau_scale):
    obs: thinks[
        department: chooses(mu in Mu, tau in Tau, wpp=partial_pooling[mu, tau](mu_scale, tau_scale)),
        department: chooses(theta in Theta, wpp=normpdfjit(theta, mu, tau))
    ]
    obs: observes_event(wpp=professor_arrival_likelihood(d, department.theta))
    obs: knows(_theta)
    return obs[Pr[department.theta == _theta]]
Plotting function
def plot_model(mu_scale=1, tau_scale=1, figsize=(10, 8)):
    posterior = partial_pooling(mu_scale=mu_scale, tau_scale=tau_scale)

    # Marginal over Tau (sum over Mu)
    posterior_tau = posterior.sum(axis=0)
    # Marginal over Mu (sum over Tau)
    posterior_mu = posterior.sum(axis=1)

    fig, axs = plt.subplots(3, 1, figsize=figsize)

    ax = axs[0]
    ax.axvline(0, color="black", linestyle="-")
    ax.plot(Mu, posterior_mu, label=r"$P(\mu \mid t)$")
    mu_expectation = jnp.dot(Mu, posterior_mu)
    ax.axvline(
        mu_expectation,
        color='red',
        linestyle='--',
        label=r"$\operatorname{E}" + rf"[\mu \mid t]={mu_expectation:6.2f}$")
    _ = ax.set_title(r"Posterior of $\mu$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    ax = axs[1]
    ax.axvline(0, color="black", linestyle="-")
    ax.plot(Tau, posterior_tau, label=r"$P(\tau \mid t)$")
    tau_expectation = jnp.dot(Tau, posterior_tau)
    ax.axvline(
        tau_expectation,
        color='red',
        linestyle='--',
        label=r"$\operatorname{E}" + rf"[\tau \mid t]={tau_expectation:6.2f}$")
    _ = ax.set_title(r"Posterior of $\tau$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    ax = axs[2]
    ax.axvline(0, color="black", linestyle="-")
    for d in Department:
        department_name = d.name
        department_abbrev = department_name[0]
        theta_posterior = department_model_theta(d, mu_scale, tau_scale)
        theta_expectation = jnp.dot(Theta, theta_posterior)
        ax.plot(
            Theta,
            theta_posterior,
            label=(
                rf"$P(\theta_{department_abbrev} \mid t),~ "
                + r"\operatorname{E}"
                + rf"[\theta_{department_abbrev} \mid t]={theta_expectation:6.2f}$"))
    _ = ax.set_title(r"Posterior of $\theta_d$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    _ = plt.suptitle(f"mu_scale = {mu_scale}, tau_scale = {tau_scale}", y=1)

    plt.tight_layout()
    plt.show()
param_list = [
    dict(mu_scale=1, tau_scale=1),
    # dict(mu_scale=5, tau_scale=5),
    # dict(mu_scale=10, tau_scale=10),
    dict(mu_scale=20, tau_scale=20),
    # dict(mu_scale=10, tau_scale=1),
    dict(mu_scale=20, tau_scale=1),
    # dict(mu_scale=1, tau_scale=10),
    dict(mu_scale=1, tau_scale=20),
]
for i_params_, params_ in enumerate(param_list):
    plot_model(**params_, figsize=(6, 5))
When tau_scale is small (e.g. { \dot\sigma_{\tau}{=}1 }), the distribution { \theta_d ~\sim~ \mathcal{N}(\mu, ~ \tau) } will be narrow, which will pressure the individual { \theta_d } to be similar to each other. When tau_scale is large (e.g. { \dot\sigma_{\tau}{=}20 }), the { \theta_d } can diverge more freely, resulting in values quite different from each other. Think about how this will interact with the number of observations and the values of those observations.
When mu_scale is small (e.g. { \dot\sigma_{\mu}{=}1 }) the location of the distribution over \theta_d will be centered near zero. When mu_scale is large (e.g. { \dot\sigma_{\mu}{=}20 }), the center of { \mathcal{N}(\mu, ~ \tau) } can move to new locations.
In this way, a large { \dot\sigma_{\mu} } and small { \dot\sigma_{\tau} } would allow { \mathcal{N}(\mu, ~ \tau) } to easily move to a new location that accommodates the broader tendencies of departments and professors, but would constrain the dispersion between departments. Notice how increasing mu_scale allows the posterior { P(\mu \mid t) } to move away from zero towards the mean of the combined data, whereas decreasing mu_scale pulls the expectation { \operatorname{E}[\mu \mid t] } towards zero. Also notice how this effect is stronger when tau_scale is small.
Pay close attention to the interaction between mu_scale, tau_scale, and the observations for each department (including the expectation and variance of the observations of each department).
Multiple observations with missing data and learned variance
\begin{align*} \mu ~\sim&~ \mathcal{N}(0, ~ \dot\sigma_{\mu}) \\ \tau ~\sim&~ \text{HalfCauchy}(\dot\sigma_{\tau}) \\ \theta_d ~\sim&~ \mathcal{N}(\mu, ~ \tau) \\ \sigma_d ~\sim&~ \text{HalfCauchy}(5) \\ t_{d,i} ~\sim&~ \mathcal{N}(\theta_d, ~ \sigma_d) \end{align*}
%reset -f
import jax
import jax.numpy as jnp
from memo import memo
from enum import IntEnum
from jax.scipy.stats.norm import pdf as normpdf
from jax.scipy.stats.norm import logpdf as normlogpdf
from jax.scipy.stats.cauchy import pdf as cauchypdf
from matplotlib import pyplot as plt

normpdfjit = jax.jit(normpdf)

class Department(IntEnum):
    GOVERNMENT = 0
    ENGLISH = 1
    MATH = 2

t = jnp.array([
    [-10, 1, 11],
    [-16, -15, -14],
    [30, jnp.nan, jnp.nan],
])

Mu = jnp.linspace(-25, 25, 100)
Tau = jnp.linspace(1, 30, 100)
Theta = jnp.linspace(-40, 40, 200)
Sigma = jnp.linspace(1, 30, 100)  ###NEW

@jax.jit
def half_cauchy(x, scale=1.0):
    return 2 * cauchypdf(x, 0, scale)

@jax.jit
def professor_arrival_likelihood(d, theta, sigma):  ###NEW
    ### likelihood of a professor from department d
    ### showing up t_d minutes early/late,
    ### under the hypothesis given by theta and sigma.
    return jnp.exp(jnp.nansum(normlogpdf(t[d], theta, sigma)))

@memo(cache=True)
def department_model[_mu: Mu, _tau: Tau](d):
    department: knows(_mu, _tau)
    department: chooses(theta in Theta, wpp=normpdfjit(theta, _mu, _tau))
    department: chooses(sigma in Sigma, wpp=half_cauchy(sigma, 5))  ###NEW
    return E[ department[professor_arrival_likelihood(d, theta, sigma)] ]  ###NEW

@memo(cache=True)
def partial_pooling[_mu: Mu, _tau: Tau](mu_scale=5, tau_scale=5):
    president: knows(_mu, _tau)
    president: thinks[
        population: chooses(mu in Mu, wpp=normpdfjit(mu, 0, mu_scale)),
        population: chooses(tau in Tau, wpp=half_cauchy(tau, tau_scale)),
    ]
    president: observes_event(wpp=department_model[population.mu, population.tau]({Department.GOVERNMENT}))
    president: observes_event(wpp=department_model[population.mu, population.tau]({Department.ENGLISH}))
    president: observes_event(wpp=department_model[population.mu, population.tau]({Department.MATH}))
    return president[Pr[population.mu == _mu, population.tau == _tau]]

@memo(cache=True)
def department_model_theta[_theta: Theta](d, mu_scale, tau_scale):
    obs: thinks[
        department: chooses(mu in Mu, tau in Tau, wpp=partial_pooling[mu, tau](mu_scale, tau_scale)),
        department: chooses(theta in Theta, wpp=normpdfjit(theta, mu, tau)),
        department: chooses(sigma in Sigma, wpp=half_cauchy(sigma, 5)),  ###NEW
    ]
    obs: observes_event(wpp=professor_arrival_likelihood(d, department.theta, department.sigma))  ###NEW
    obs: knows(_theta)
    return obs[Pr[department.theta == _theta]]

@memo(cache=True)
def department_model_sigma[_sigma: Sigma](d, mu_scale, tau_scale):  ###NEW
    obs: thinks[
        department: chooses(mu in Mu, tau in Tau, wpp=partial_pooling[mu, tau](mu_scale, tau_scale)),
        department: chooses(theta in Theta, wpp=normpdfjit(theta, mu, tau)),
        department: chooses(sigma in Sigma, wpp=half_cauchy(sigma, 5)),
    ]
    obs: observes_event(wpp=professor_arrival_likelihood(d, department.theta, department.sigma))
    obs: knows(_sigma)
    return obs[Pr[department.sigma == _sigma]]
Caching function
# Cache for precomputed results
cache = {}

def compute_distributions(mu_scale, tau_scale, verbose=False):
    """Retrieve from cache or compute the JAX-based distributions."""
    key = (mu_scale, tau_scale)
    if key in cache:
        return cache[key]  # Use cached results
    if verbose:
        print(f"{mu_scale=}, {tau_scale=}")
    # Cache results
    cache[key] = (
        partial_pooling(mu_scale=mu_scale, tau_scale=tau_scale).sum(axis=1),
        partial_pooling(mu_scale=mu_scale, tau_scale=tau_scale).sum(axis=0),
        {d: department_model_sigma(d, mu_scale, tau_scale) for d in Department},
        {d: department_model_theta(d, mu_scale, tau_scale) for d in Department},
    )
    return cache[key]
Plotting function
def plot_model(mu_scale=1, tau_scale=1, figsize=(10, 8)):
    posterior_mu, posterior_tau, sigma_posteriors, theta_posteriors = compute_distributions(mu_scale, tau_scale)

    fig, axs = plt.subplots(4, 1, figsize=figsize)

    ax = axs[0]
    ax.axvline(0, color="black", linestyle="-")
    ax.plot(Mu, posterior_mu, label=r"$P(\mu \mid t)$")
    mu_expectation = jnp.dot(Mu, posterior_mu)
    ax.axvline(
        mu_expectation,
        color='red',
        linestyle='--',
        label=r"$\operatorname{E}" + rf"[\mu \mid t]={mu_expectation:6.2f}$")
    _ = ax.set_title(r"Posterior of $\mu$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    ax = axs[1]
    ax.axvline(0, color="black", linestyle="-")
    ax.plot(Tau, posterior_tau, label=r"$P(\tau \mid t)$")
    tau_expectation = jnp.dot(Tau, posterior_tau)
    ax.axvline(
        tau_expectation,
        color='red',
        linestyle='--',
        label=r"$\operatorname{E}" + rf"[\tau \mid t]={tau_expectation:6.2f}$")
    _ = ax.set_title(r"Posterior of $\tau$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    ax = axs[2]
    ax.axvline(0, color="black", linestyle="-")
    for d in Department:
        department_name = d.name
        department_abbrev = department_name[0]
        sigma_posterior = sigma_posteriors[d]
        sigma_expectation = jnp.dot(Sigma, sigma_posterior)
        ax.plot(
            Sigma,
            sigma_posterior,
            label=(
                rf"$P(\sigma_{department_abbrev} \mid t),~ "
                + r"\operatorname{E}"
                + rf"[\sigma_{department_abbrev} \mid t]={sigma_expectation:6.2f}$"))
    _ = ax.set_xlim(-1, 20)
    _ = ax.set_title(r"Posterior of $\sigma_d$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    ax = axs[3]
    ax.axvline(0, color="black", linestyle="-")
    for d in Department:
        department_name = d.name
        department_abbrev = department_name[0]
        theta_posterior = theta_posteriors[d]
        theta_expectation = jnp.dot(Theta, theta_posterior)
        ax.plot(
            Theta,
            theta_posterior,
            label=(
                rf"$P(\theta_{department_abbrev} \mid t),~ "
                + r"\operatorname{E}"
                + rf"[\theta_{department_abbrev} \mid t]={theta_expectation:6.2f}$"))
    _ = ax.set_xlim(-20, 40)
    _ = ax.set_title(r"Posterior of $\theta_d$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    _ = plt.suptitle(f"mu_scale = {mu_scale}, tau_scale = {tau_scale}", y=1)

    plt.tight_layout()
    plt.show()
param_list = [
    dict(mu_scale=1, tau_scale=1),
    # dict(mu_scale=5, tau_scale=5),
    dict(mu_scale=10, tau_scale=10),
    # dict(mu_scale=20, tau_scale=20),
    dict(mu_scale=10, tau_scale=1),
    # dict(mu_scale=20, tau_scale=1),
    dict(mu_scale=1, tau_scale=10),
    # dict(mu_scale=1, tau_scale=20),
]
plt.close("all")
for i_params_, params_ in enumerate(param_list):
    plot_model(**params_, figsize=(10, 8))
Models contrasted
Complete pooling:
\begin{align*} (\mu \quad&~ \text{is fixed}) \\ (\tau \quad&~ \text{is fixed}) \\ (\sigma \quad&~ \text{is fixed}) \\ \theta ~\sim&~ \mathcal{N}( \mu, ~ \tau ) \\ t_d ~\sim&~ \mathcal{N}(\theta, ~ \sigma) \end{align*}
No pooling:
\begin{align*} (\mu \quad&~ \text{is fixed}) \\ (\tau \quad&~ \text{is fixed}) \\ (\sigma \quad&~ \text{is fixed}) \\ \theta_d ~\sim&~ \mathcal{N}( \mu, ~ \tau ) \\ t_d ~\sim&~ \mathcal{N}(\theta_d, ~ \sigma) \end{align*}
Partial pooling:
\begin{align*} (\sigma_{\mu} \quad&~ \text{is fixed}) \\ (\sigma_{\tau} \quad&~ \text{is fixed}) \\ \mu ~\sim&~ \mathcal{N}(0, ~ \sigma_{\mu}) \\ \tau ~\sim&~ \text{HalfCauchy}(\sigma_{\tau}) \\ (\sigma \quad&~ \text{is fixed}) \\ \theta_d ~\sim&~ \mathcal{N}(\mu, ~ \tau ) \\ t_d ~\sim&~ \mathcal{N}(\theta_d, ~ \sigma) \end{align*}
Partial pooling, multiple observations:
\begin{align*} (\sigma_{\mu} \quad&~ \text{is fixed}) \\ (\sigma_{\tau} \quad&~ \text{is fixed}) \\ \mu ~\sim&~ \mathcal{N}(0, ~ \sigma_{\mu}) \\ \tau ~\sim&~ \text{HalfCauchy}(\sigma_{\tau}) \\ (\sigma \quad&~ \text{is fixed}) \\ \theta_d ~\sim&~ \mathcal{N}(\mu, ~ \tau ) \\ t_{d,i} ~\sim&~ \mathcal{N}(\theta_d, ~ \sigma) \end{align*}
Partial pooling, multiple observations, learned variance:
\begin{align*} (\sigma_{\mu} \quad&~ \text{is fixed}) \\ (\sigma_{\tau} \quad&~ \text{is fixed}) \\ \mu ~\sim&~ \mathcal{N}(0, ~ \sigma_{\mu}) \\ \tau ~\sim&~ \text{HalfCauchy}(\sigma_{\tau}) \\ \sigma_d ~\sim&~ \text{HalfCauchy}(5) \\ \theta_d ~\sim&~ \mathcal{N}(\mu, ~ \tau) \\ t_{d,i} ~\sim&~ \mathcal{N}(\theta_d, ~ \sigma_d) \end{align*}
- McElreath (2016, Chapters 1, 13, 14)
- Gelman (2014, Chapter 5)
- Kruschke (2015, Chapter 9)
%reset -f
import sys
import platform
import importlib.metadata

print("Python:", sys.version)
print("Platform:", platform.system(), platform.release())
print("Processor:", platform.processor())
print("Machine:", platform.machine())

print("\nPackages:")
for name, version in sorted(
    ((dist.metadata["Name"], dist.version) for dist in importlib.metadata.distributions()),
    key=lambda x: x[0].lower()  # Sort case-insensitively
):
    print(f"{name}=={version}")
Python: 3.13.3 | packaged by conda-forge | (main, Apr 14 2025, 20:44:30) [Clang 18.1.8 ]
Platform: Darwin 23.6.0
Processor: arm
Machine: arm64
Packages:
annotated-types==0.7.0
anyio==4.9.0
appnope==0.1.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
astroid==3.3.10
asttokens==3.0.0
async-lru==2.0.5
attrs==25.3.0
babel==2.17.0
beautifulsoup4==4.13.4
bleach==6.2.0
certifi==2025.4.26
cffi==1.17.1
cfgv==3.4.0
charset-normalizer==3.4.2
click==8.2.0
comm==0.2.2
contourpy==1.3.2
cycler==0.12.1
debugpy==1.8.14
decorator==5.2.1
defusedxml==0.7.1
dill==0.4.0
distlib==0.3.9
executing==2.2.0
fastjsonschema==2.21.1
filelock==3.18.0
fonttools==4.58.0
fqdn==1.5.1
h11==0.16.0
httpcore==1.0.9
httpx==0.28.1
identify==2.6.10
idna==3.10
importlib_metadata==8.7.0
ipykernel==6.29.5
ipython==9.2.0
ipython_pygments_lexers==1.1.1
ipywidgets==8.1.7
isoduration==20.11.0
isort==6.0.1
jax==0.6.0
jaxlib==0.6.0
jedi==0.19.2
Jinja2==3.1.6
joblib==1.5.0
json5==0.12.0
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2025.4.1
jupyter-cache==1.0.1
jupyter-events==0.12.0
jupyter-lsp==2.2.5
jupyter_client==8.6.3
jupyter_core==5.7.2
jupyter_server==2.16.0
jupyter_server_terminals==0.5.3
jupyterlab==4.4.2
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.3
jupyterlab_widgets==3.0.15
kiwisolver==1.4.8
MarkupSafe==3.0.2
matplotlib==3.10.3
matplotlib-inline==0.1.7
mccabe==0.7.0
memo-lang==1.2.0
mistune==3.1.3
ml_dtypes==0.5.1
nbclient==0.10.2
nbconvert==7.16.6
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.4.2
nodeenv==1.9.1
notebook_shim==0.2.4
numpy==2.2.6
opt_einsum==3.4.0
optype==0.9.3
overrides==7.7.0
packaging==25.0
pandas==2.2.3
pandas-stubs==2.2.3.250308
pandocfilters==1.5.1
parso==0.8.4
pexpect==4.9.0
pillow==11.2.1
platformdirs==4.3.8
plotly==5.24.1
pre_commit==4.2.0
prometheus_client==0.22.0
prompt_toolkit==3.0.51
psutil==7.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
pycparser==2.22
pydantic==2.11.4
pydantic_core==2.33.2
Pygments==2.19.1
pygraphviz==1.14
pylint==3.3.7
pyparsing==3.2.3
python-dateutil==2.9.0.post0
python-dotenv==1.1.0
python-json-logger==3.3.0
pytz==2025.2
PyYAML==6.0.2
pyzmq==26.4.0
referencing==0.36.2
requests==2.32.3
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.25.0
ruff==0.11.10
scikit-learn==1.6.1
scipy==1.15.3
scipy-stubs==1.15.3.0
seaborn==0.13.2
Send2Trash==1.8.3
setuptools==80.7.1
six==1.17.0
sniffio==1.3.1
soupsieve==2.7
SQLAlchemy==2.0.41
stack-data==0.6.3
tabulate==0.9.0
tenacity==9.1.2
terminado==0.18.1
threadpoolctl==3.6.0
tinycss2==1.4.0
toml==0.10.2
tomlkit==0.13.2
tornado==6.5
tqdm==4.67.1
traitlets==5.14.3
types-python-dateutil==2.9.0.20250516
types-pytz==2025.2.0.20250516
typing-inspection==0.4.0
typing_extensions==4.13.2
tzdata==2025.2
uri-template==1.3.0
urllib3==2.4.0
virtualenv==20.31.2
wcwidth==0.2.13
webcolors==24.11.1
webencodings==0.5.1
websocket-client==1.8.0
widgetsnbextension==4.0.14
xarray==2025.4.0
zipp==3.21.0