Multilevel models

Jupyter Notebook

NB the Jupyter notebook includes interactive widgets that are very useful for building an understanding of these models.

Departmental tardiness

Imagine that your university hires a new president. During her first week, she schedules a meeting with professors from three departments, { D = \{\text{G}, \text{E}, \text{M} \} }. The professor from the Department of Government ({ d{=}\text{G} }) shows up 1 minute late, the professor from the Department of English ({ d{=}\text{E} }) shows up 15 minutes before the scheduled meeting time, and the professor from the Department of Math ({ d{=}\text{M} }) shows up 30 minutes late. From this single meeting the new president updates her beliefs about when professors will show up to scheduled meetings.

But how should the president update her beliefs? It could be that departments have different relationships with punctuality, in which case it would be useful to expect professors from the Math department to show up later than professors from the English department. It could also be that there is no meaningful between-department variance, and she should just learn about the greater population of professors at the university.

A model that infers information about the population without differentiating subpopulations is doing “complete variance pooling”, meaning that all of the variance in the data is treated as being generated by the same distribution. I.e. the data is pooled together and treated as iid observations drawn from a single distribution.

If the president thinks of each department independently, what she learns about the tardiness of professors from one department will not change her expectation about professors from other departments. In this case, there would be “no variance pooling” between departments. The president’s observations of the three professors have no bearing on her belief about what to expect of a professor from a fourth department (e.g. Psychology).

“Partial variance pooling” models represent uncertainty at multiple levels of abstraction. Based on her observation from this meeting, the president will update her beliefs about these departments and the greater population. In other words, what she observes about the Math department might strongly update her beliefs about that department, moderately update her beliefs about the broader population, and that in turn could weakly update her beliefs about the English department.

Complete pooling

Let’s build a complete pooling model of the president’s belief about the tardiness of professors. The data observed by the president is { t = \langle t_G, t_E, t_M \rangle = \langle 1, -15, 30 \rangle }. We’ll model the president’s belief about the greater population of professors as a normal distribution centered at \theta with standard deviation \dot\sigma. Prior to meeting with anyone, the president thinks that professors will, on average, show up on time, but that some will show up early and some late. We can encode this prior belief as a normal distribution centered at zero, with some standard deviation, let’s say 20. I.e. the prior for \theta is:

\theta ~\sim~ \mathcal{N}\left( \dot{\mu}{=}0, \dot{\tau}{=}20 \right)

If she observes that professors tend to be late, then her posterior estimate of \theta will have an expected value greater than zero (i.e. { \operatorname{E}[\theta \mid t] > 0 }), and if she ends up expecting professors to be early, then { \operatorname{E}[\theta \mid t] < 0 }.

In this tutorial I’ll use the dot above to indicate that a variable is fixed, not inferred (the dot is not standard notation, just what I’m using here). In the graphical schematic, fixed parameters are depicted as bare symbols rather than nodes. Note that “fixed” means the values are not updated during inference (but you can, of course, manually tune these values or learn suitable values by minimizing some loss function in conjunction with a model fitting algorithm).

Since the model infers the distribution of \theta, the graphical schematic depicts it as an open node (i.e. a latent random variable). The data in this model are t.

We specify the likelihood as,

t_d ~\sim~ \mathcal{N}(\theta, \dot\sigma{=}15)

I.e., whatever the true value of \theta is for the population of professors, the president thinks that the times when professors actually show up will follow a normal distribution, centered on the true value of \theta, with standard deviation 15.

Complete Variance Pooling

\begin{align*} \theta ~\sim&~ \mathcal{N}\left( \dot{\mu}{=}0, \dot{\tau}{=}20 \right) \\ t_d ~\sim&~ \mathcal{N}(\theta, \dot\sigma{=}15) \\ d ~\in&~ D,~~~ \text{where}~~~ D = \{ \text{G}, \text{E}, \text{M} \} \end{align*}
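
As an aside (a standard conjugate-normal result added here for reference; it is not part of the model description above), because both the prior and the likelihood are normal, the posterior over \theta is available in closed form:

\begin{align*} \theta \mid t ~\sim&~ \mathcal{N}\!\left( \frac{n \bar{t} / \dot\sigma^2}{\,n/\dot\sigma^2 + 1/\dot\tau^2\,},~ \left(n/\dot\sigma^2 + 1/\dot\tau^2\right)^{-1/2} \right) \end{align*}

With n{=}3 observations, sample mean { \bar{t} = 16/3 \approx 5.33 }, { \dot\sigma{=}15 }, and { \dot\tau{=}20 }, this gives a posterior mean of about 4.5 and a posterior standard deviation of about 7.9, which the grid-based memo model below should approximately recover.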

%reset -f
import jax
import jax.numpy as jnp
from memo import memo
from enum import IntEnum
from jax.scipy.stats.norm import pdf as normpdf
from jax.scipy.stats.cauchy import pdf as cauchypdf
from matplotlib import pyplot as plt

normpdfjit = jax.jit(normpdf)

class Department(IntEnum):
    GOVERNMENT = 0
    ENGLISH = 1
    MATH = 2

t = jnp.array([1, -15, 30])
sigma = 15

Theta = jnp.linspace(-40, 40, 200)

@jax.jit
def professor_arrival_likelihood(d, theta):
    ### likelihood of a professor from department d 
    ### showing up t_d minutes early/late, 
    ### for a given theta and sigma.
    return normpdf(t[d], theta, sigma)

@memo
def complete_pooling[
    _theta: Theta,
](mu=0, tau=1):
    president: knows(_theta)
    president: thinks[
        department: chooses(d in Department, wpp=1),
        department: chooses(theta in Theta, wpp=normpdfjit(theta, mu, tau))
    ]
    president: observes_event(
        wpp=professor_arrival_likelihood(department.d, department.theta))
    return president[Pr[department.theta == _theta]]

mu_ = 0
tau_ = 20
res = complete_pooling(mu=mu_, tau=tau_)

### check the size and sum of the output
# res.shape
# res.sum()

fig, ax = plt.subplots()

ax.axvline(0, color="black", linestyle="-")
theta_posterior = res
theta_expectation = jnp.dot(Theta, theta_posterior)
ax.plot(Theta, theta_posterior, label=r"$P(\theta \mid t)$")
ax.axvline(
    theta_expectation, 
    color='red', 
    linestyle='--', 
    label=(
        r"$\operatorname{E}"
        + rf"[\theta \mid t]={theta_expectation:0.2f}$")
)
_ = ax.set_title(r"Posterior of $\theta$")
_ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

_ = plt.suptitle(fr"$\dot\tau$ = {tau_}", y=1)
plt.tight_layout()
plt.show()

Try changing the data to t = jnp.array([-20, -10, 30]). Notice how changing the data for one group updates the expectation of the population.
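
If you want a quick cross-check without rerunning the whole cell, here is a minimal grid-Bayes sketch (plain JAX rather than memo, reusing Theta, sigma, and normpdf from the cell above) showing how the pooled expectation { \operatorname{E}[\theta \mid t] } shifts when the data change:

def pooled_expectation(data, mu=0.0, tau=20.0):
    prior = normpdf(Theta, mu, tau)  # P(theta) on the grid
    # iid likelihood: product over the observations at each grid point
    lik = jnp.prod(normpdf(data[None, :], Theta[:, None], sigma), axis=1)
    post = prior * lik
    post = post / post.sum()  # normalize over the grid
    return jnp.dot(Theta, post)  # E[theta | t]

print(pooled_expectation(jnp.array([1.0, -15.0, 30.0])))    # original data
print(pooled_expectation(jnp.array([-20.0, -10.0, 30.0])))  # modified data

The first call should land near the closed-form value from above (about 4.5); the second moves toward zero because the pooled sample mean of the modified data is 0.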

No pooling

In contrast to complete pooling, which assumes all observations are generated by a shared population-level distribution parameterized by \theta, the no pooling model treats each department as having its own independent parameter \theta_d. The key change is that instead of having a single \theta drawn from the prior distribution, we now have three independent parameters \theta_G, \theta_E, and \theta_M. The practical meaning of this change is significant:

In complete pooling, an observation of the Math professor being late would shift our expectations for all departments. In no pooling, learning that the Math professor is late only updates our beliefs about the Math department. Each department’s parameter is estimated using only data from that department.

In the model specification, this amounts to replacing the single \theta with department-specific parameters \theta_d.

No Variance Pooling

\begin{align*} \theta_d ~\sim&~ \mathcal{N}\left( \dot{\mu}{=}0, \dot{\tau}{=}20 \right) \\ t_d ~\sim&~ \mathcal{N}(\theta_d, \dot\sigma{=}15) \\ d ~\in&~ D,~~~ \text{where}~~~ D = \{\text{G}, \text{E}, \text{M} \} \end{align*}

The plate (rectangular box) with label d indicates that everything inside is repeated for each department d \in D. Variables inside the plate have a unique value for each department. Variables outside the plate are shared across all departments.

To change the complete_pooling model to the no_pooling model we make a simple change, adding the conditional statement

president: observes [department.d] is _d

and the query variable _d in the definition, [_d: Department, ...].

%reset -f
import jax
import jax.numpy as jnp
from memo import memo
from enum import IntEnum
from jax.scipy.stats.norm import pdf as normpdf
from jax.scipy.stats.cauchy import pdf as cauchypdf
from matplotlib import pyplot as plt

normpdfjit = jax.jit(normpdf)

class Department(IntEnum):
    GOVERNMENT = 0
    ENGLISH = 1
    MATH = 2

t = jnp.array([1, -15, 30])
sigma = 15

Theta = jnp.linspace(-40, 40, 200)

@jax.jit
def professor_arrival_likelihood(d, theta):
    ### likelihood of a professor from department d 
    ### showing up t_d minutes early/late, 
    ### for a given theta and sigma.
    return normpdf(t[d], theta, sigma)

@memo
def no_pooling[
    _d: Department,
    _theta: Theta,
](mu=0, tau=1):
    president: knows(_theta)
    president: thinks[
        department: chooses(d in Department, wpp=1),
        department: chooses(theta in Theta, wpp=normpdfjit(theta, mu, tau))
    ]
    president: observes [department.d] is _d ### new conditional statement
    president: observes_event(
        wpp=professor_arrival_likelihood(department.d, department.theta))
    return president[Pr[department.theta == _theta]]

mu_ = 0
tau_ = 20
res = no_pooling(mu=mu_, tau=tau_)

### check the size and sum of the output
# res.shape
# res.sum()
# res[0].sum()
# res[1].sum()
# res[2].sum()

fig, ax = plt.subplots()

ax.axvline(0, color="black", linestyle="-")
for d in Department:
    department_name = d.name
    department_abbrev = department_name[0]
    theta_posterior = res[d]
    theta_expectation = jnp.dot(Theta, theta_posterior)
    ax.plot(
        Theta, 
        theta_posterior, 
        label=(
            rf"$p(\theta_{department_abbrev} \mid y),~ "
            + r"\operatorname{E}"
            + rf"[\theta_{department_abbrev} \mid t]={theta_expectation:0.2f}$")
    )
_ = ax.set_title(r"Posterior of $\theta_d$")
_ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

_ = plt.suptitle(fr"$\dot\tau$ = {tau_}", y=1)
plt.tight_layout()
plt.show()

Try changing the data to t = jnp.array([-20, -15, 30]). Notice how changing the data for one group does not affect the expectation of the posterior of the other groups.

Previously, the complete_pooling model returned an array of shape (200,) with 200 being the size of the sample space of Theta. With the addition of the conditional statement, what array shape does no_pooling return? Make sure you understand the shape change, what each dimension reflects, and why the values sum in the ways that they do.

The key theoretical distinction is that the no pooling model asserts that the timing behavior of professors in one department is independent of the timing behavior of professors in every other department (i.e. when professors from one department show up for meetings carries no information about when professors from a different department will show up). This is an unrealistically strong assumption.

The no_pooling and complete_pooling models highlight an important practical tradeoff: no pooling can better capture differences between departments but may suffer from high variance in its estimates due to small sample sizes within each department. Complete pooling has lower variance but may miss important systematic differences between departments.
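
To see this tradeoff concretely, here is a small grid-Bayes sketch (plain JAX rather than memo, reusing Theta, t, sigma, normpdf, and Department from the cell above) that compares posterior variances under the two assumptions:

prior = normpdf(Theta, 0.0, 20.0)  # same prior as above: N(0, 20)

# complete pooling: one posterior over theta, informed by all three observations
post_complete = prior * jnp.prod(normpdf(t[None, :], Theta[:, None], sigma), axis=1)
post_complete = post_complete / post_complete.sum()

# no pooling: a separate posterior per department, each informed by one observation
post_none = prior[None, :] * normpdf(t[:, None], Theta[None, :], sigma)
post_none = post_none / post_none.sum(axis=1, keepdims=True)

def posterior_variance(p):
    mean = jnp.dot(Theta, p)
    return jnp.dot(Theta**2, p) - mean**2

print("complete pooling:", float(posterior_variance(post_complete)))
for d in Department:
    print("no pooling,", d.name, ":", float(posterior_variance(post_none[d])))

With only a single observation per department, the per-department posteriors come out noticeably wider than the pooled posterior.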

Partial pooling

In partial pooling, every observation updates beliefs across multiple levels of abstraction. The critical insight is that while departments may differ in their typical timing behavior, the departments are not statistically independent. They are influenced by common causes. To model the president’s reasoning about these multiple levels of uncertainty, we modify the no_pooling model by adding priors over \mu and \tau.

Rather than being fixed values, \mu and \tau become latent parameters to be jointly inferred alongside \theta_d.

The higher-level latent parameters \mu and \tau characterize a “population-level” distribution from which department-specific \theta_d values are generated.

The key change to the model is that rather than { \theta_d ~\sim~ \mathcal{N}\left( \dot{\mu}{=}0, \dot{\tau}{=}20 \right) }, we have { \theta_d ~\sim~ \mathcal{N}\left(\mu, \tau \right) } and specify hyperpriors on these new latent parameters: { \mu ~\sim~ \mathcal{N}(0, \dot\sigma_{\mu}) } and { \tau ~\sim~ \text{HalfCauchy}(\dot\sigma_{\tau}) }.

Note that there are new fixed parameters: mu_scale ({ \dot\sigma_\mu}) and tau_scale ({ \dot\sigma_\tau }), but I am omitting them from the graphical schematic for the sake of visual simplicity.

Partial Variance Pooling

\begin{align*} \mu ~\sim&~ \mathcal{N}(0, \dot\sigma_{\mu}) \\ \tau ~\sim&~ \text{HalfCauchy}(\dot\sigma_{\tau}) \\ \theta_d ~\sim&~ \mathcal{N}\left(\mu, \tau \right) \\ t_d ~\sim&~ \mathcal{N}(\theta_d, \dot\sigma{=}15) \\ d ~\in&~ D,~~~ \text{where}~~~ D = \{\text{G}, \text{E}, \text{M} \} \end{align*}

%reset -f
import jax
import jax.numpy as jnp
from memo import memo
from enum import IntEnum
from jax.scipy.stats.norm import pdf as normpdf
from jax.scipy.stats.cauchy import pdf as cauchypdf
from matplotlib import pyplot as plt

normpdfjit = jax.jit(normpdf)

class Department(IntEnum):
    GOVERNMENT = 0
    ENGLISH = 1
    MATH = 2

t = jnp.array([1, -15, 30])
sigma = 15

Mu = jnp.linspace(-25, 25, 100)  ### sample space for new hyperprior
Tau = jnp.linspace(1, 30, 100)  ### sample space for new hyperprior
Theta = jnp.linspace(-40, 40, 200)

### PDF for new hyperprior: the half-Cauchy density is twice the
### Cauchy(0, scale) density, restricted to x >= 0 (the Tau grid starts at 1)
@jax.jit
def half_cauchy(x, scale=1.0):
    return 2 * cauchypdf(x, 0, scale)

@jax.jit
def professor_arrival_likelihood(d, theta):
    ### likelihood of a professor from department d 
    ### showing up t_d minutes early/late, 
    ### for a given theta and sigma.
    return normpdf(t[d], theta, sigma)

@memo
def department_model[_mu: Mu, _tau: Tau](d):
    department: knows(_mu, _tau)
    department: chooses(theta in Theta, wpp=normpdfjit(theta, _mu, _tau))
    return E[ department[professor_arrival_likelihood(d, theta)] ]

@memo
def partial_pooling[_mu: Mu, _tau: Tau](mu_scale=5, tau_scale=5):
    president: knows(_mu, _tau)
    president: thinks[
        population: chooses(mu in Mu, wpp=normpdfjit(mu, 0, mu_scale)),
        population: chooses(tau in Tau, wpp=half_cauchy(tau, tau_scale)),
    ]

    president: observes_event(
        wpp=department_model[population.mu, population.tau]({Department.GOVERNMENT}))
    president: observes_event(
        wpp=department_model[population.mu, population.tau]({Department.ENGLISH}))
    president: observes_event(
        wpp=department_model[population.mu, population.tau]({Department.MATH}))

    return president[Pr[population.mu == _mu, population.tau == _tau]]

Let’s think about what each variable represents.

The likelihood model professor_arrival_likelihood gives the probability that a professor from department d shows up t_d minutes early/late, under a given \theta_d and \sigma:

t_d ~\sim~ \mathcal{N}(\theta_d, \dot\sigma{=}15)

We can write the likelihood as { P(t_d \mid \theta_d; \dot\sigma) }; given \theta_d, the data are conditionally independent of \mu and \tau. The marginal posterior { P(\theta_d \mid t) } will thus represent the president’s belief about the true value of \theta_d for a given department d.
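
To spell out what the two memo programs above compute (my summary, stated in the notation of this section): department_model integrates out \theta_d to return the per-department marginal likelihood, and partial_pooling conditions on all three departments to produce the joint posterior over the hyperparameters,

\begin{align*} P(t_d \mid \mu, \tau) ~=&~ \int \mathcal{N}(\theta_d ; \mu, \tau)\, \mathcal{N}(t_d ; \theta_d, \dot\sigma)\, d\theta_d \\ P(\mu, \tau \mid t) ~\propto&~ \mathcal{N}(\mu ; 0, \dot\sigma_{\mu})\, \text{HalfCauchy}(\tau ; \dot\sigma_{\tau}) \prod_{d \in D} P(t_d \mid \mu, \tau) \end{align*}

The next program, department_model_theta, uses this hyperparameter posterior together with the likelihood of t_d to recover the department-level posterior over \theta_d.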

@memo
def department_model_theta[_theta: Theta](d, mu_scale, tau_scale):
    obs: thinks[
        department: chooses(mu in Mu, tau in Tau, wpp=partial_pooling[mu, tau](mu_scale, tau_scale)),
        department: chooses(theta in Theta, wpp=normpdfjit(theta, mu, tau))
    ]
    obs: observes_event(wpp=professor_arrival_likelihood(d, department.theta))
    obs: knows(_theta)
    return obs[Pr[department.theta == _theta]]
Plotting function
def plot_model(mu_scale=1, tau_scale=1, figsize=(10, 8), verbose=False):
    posterior = partial_pooling(mu_scale=mu_scale, tau_scale=tau_scale)

    # Marginal over Tau (sum over Mu)
    posterior_tau = posterior.sum(axis=0)
    # Marginal over Mu (sum over Tau)
    posterior_mu = posterior.sum(axis=1)

    fig, axs = plt.subplots(3, 1, figsize=figsize)

    ax = axs[0]
    ax.axvline(0, color="black", linestyle="-")
    ax.plot(Mu, posterior_mu, label=r"$P(\mu \mid t)$")
    mu_expectation = jnp.dot(Mu, posterior_mu)
    ax.axvline(
        mu_expectation, 
        color='red', 
        linestyle='--', 
        label=r"$\operatorname{E}" + rf"[\mu \mid t]={mu_expectation:0.2f}$")
    _ = ax.set_title(r"Posterior of $\mu$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    ax = axs[1]
    ax.axvline(0, color="black", linestyle="-")
    ax.plot(Tau, posterior_tau, label=r"$P(\tau \mid t)$")
    tau_expectation = jnp.dot(Tau, posterior_tau)
    ax.axvline(
        tau_expectation, 
        color='red', 
        linestyle='--', 
        label=r"$\operatorname{E}" + rf"[\tau \mid t]={tau_expectation:0.2f}$")
    _ = ax.set_title(r"Posterior of $\tau$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    ax = axs[2]
    ax.axvline(0, color="black", linestyle="-")
    for d in Department:
        department_name = d.name
        department_abbrev = department_name[0]
        theta_posterior = department_model_theta(d, mu_scale, tau_scale)
        theta_expectation = jnp.dot(Theta, theta_posterior)
        ax.plot(
            Theta, 
            theta_posterior, 
            label=(
                rf"$P(\theta_{department_abbrev} \mid t),~ " 
                + r"\operatorname{E}" 
                + rf"[\theta_{department_abbrev} \mid t]={theta_expectation:0.2f}$"))
    _ = ax.set_title(r"Posterior of $\theta_d$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    _ = plt.suptitle(f"mu_scale = {mu_scale}, tau_scale = {tau_scale}", y=1)
    plt.tight_layout()
    plt.show()

    if verbose:
        for d in Department:
            department_name = d.name
            department_abbrev = department_name[0]
            theta_posterior = department_model_theta(d, mu_scale, tau_scale)
            posterior_mean = jnp.average(Theta, weights=theta_posterior)
            posterior_second_moment = jnp.average(Theta**2, weights=theta_posterior)
            posterior_variance = posterior_second_moment - posterior_mean**2
            print(f"{mu_scale=}, {tau_scale=} : E[θ{department_abbrev} | t]={posterior_mean:0.3f}, Var[θ{department_abbrev} | t]={posterior_variance:0.3f}")

When mu_scale is small (e.g. { \dot\sigma_{\mu}{=}1 }), the location of the distribution over \theta_d will be centered near zero. When mu_scale is large (e.g. { \dot\sigma_{\mu}{=}10 }), the center of { \mathcal{N}(\mu, \tau) } can move to new locations, resulting in posterior distributions with similar variance but expected values quite different from each other.

param_list = [
    dict(mu_scale=1, tau_scale=3),
    dict(mu_scale=3, tau_scale=3),
    dict(mu_scale=10, tau_scale=3),
]
for i_params_, params_ in enumerate(param_list):
    plot_model(**params_, figsize=(6, 5), verbose=True)

mu_scale=1, tau_scale=3 : E[θG | t]=0.187, Var[θG | t]=29.778
mu_scale=1, tau_scale=3 : E[θE | t]=-2.155, Var[θE | t]=40.307
mu_scale=1, tau_scale=3 : E[θM | t]=6.092, Var[θM | t]=82.286

mu_scale=3, tau_scale=3 : E[θG | t]=0.595, Var[θG | t]=35.435
mu_scale=3, tau_scale=3 : E[θE | t]=-2.168, Var[θE | t]=46.285
mu_scale=3, tau_scale=3 : E[θM | t]=6.891, Var[θM | t]=79.738

mu_scale=10, tau_scale=3 : E[θG | t]=2.370, Var[θG | t]=61.138
mu_scale=10, tau_scale=3 : E[θE | t]=-2.297, Var[θE | t]=73.436
mu_scale=10, tau_scale=3 : E[θM | t]=10.948, Var[θM | t]=80.582



When tau_scale is small (e.g. { \dot\sigma_{\tau}{=}1 }), the distribution { \theta_d ~\sim~ \mathcal{N}(\mu, \tau) } will be narrow, which will pressure the individual { \theta_d } to be similar to each other. When tau_scale is large (e.g. { \dot\sigma_{\tau}{=}20 }), the { \theta_d } can diverge more freely, resulting in larger posterior variance and expected values quite different from each other.

param_list = [
    dict(mu_scale=1, tau_scale=1),
    dict(mu_scale=1, tau_scale=3),
    dict(mu_scale=1, tau_scale=10),
    dict(mu_scale=1, tau_scale=20),
]
for i_params_, params_ in enumerate(param_list):
    plot_model(**params_, figsize=(6, 5), verbose=True)

mu_scale=1, tau_scale=1 : E[θG | t]=0.139, Var[θG | t]=17.512
mu_scale=1, tau_scale=1 : E[θE | t]=-1.247, Var[θE | t]=24.192
mu_scale=1, tau_scale=1 : E[θM | t]=3.798, Var[θM | t]=55.836

mu_scale=1, tau_scale=3 : E[θG | t]=0.187, Var[θG | t]=29.778
mu_scale=1, tau_scale=3 : E[θE | t]=-2.155, Var[θE | t]=40.307
mu_scale=1, tau_scale=3 : E[θM | t]=6.092, Var[θM | t]=82.286

mu_scale=1, tau_scale=10 : E[θG | t]=0.298, Var[θG | t]=57.456
mu_scale=1, tau_scale=10 : E[θE | t]=-4.140, Var[θE | t]=73.415
mu_scale=1, tau_scale=10 : E[θM | t]=10.436, Var[θM | t]=116.686

mu_scale=1, tau_scale=20 : E[θG | t]=0.366, Var[θG | t]=74.551
mu_scale=1, tau_scale=20 : E[θE | t]=-5.324, Var[θE | t]=91.824
mu_scale=1, tau_scale=20 : E[θM | t]=12.698, Var[θM | t]=127.697

In this way, a large { \dot\sigma_{\mu} } and small { \dot\sigma_{\tau} } would allow { \mathcal{N}(\mu, \tau) } to easily move to a new location that accommodates the broader tendencies of departments and professors, while constraining the dispersion between departments. Notice how increasing mu_scale allows the posterior { P(\mu \mid t) } to move away from zero towards the mean of the combined data, whereas decreasing mu_scale pulls the expectation { \operatorname{E}[\mu \mid t] } towards zero. Also notice how this effect is stronger when tau_scale is small.

param_list = [
    dict(mu_scale=1, tau_scale=1),
    # dict(mu_scale=5, tau_scale=5),
    # dict(mu_scale=10, tau_scale=10),
    dict(mu_scale=20, tau_scale=20),
    # dict(mu_scale=10, tau_scale=1),
    dict(mu_scale=20, tau_scale=3),
    # dict(mu_scale=1, tau_scale=10),
    dict(mu_scale=3, tau_scale=20),
]
for i_params_, params_ in enumerate(param_list):
    plot_model(**params_, figsize=(6, 5), verbose=True)

mu_scale=1, tau_scale=1 : E[θG | t]=0.139, Var[θG | t]=17.512
mu_scale=1, tau_scale=1 : E[θE | t]=-1.247, Var[θE | t]=24.192
mu_scale=1, tau_scale=1 : E[θM | t]=3.798, Var[θM | t]=55.836

mu_scale=20, tau_scale=20 : E[θG | t]=2.517, Var[θG | t]=112.461
mu_scale=20, tau_scale=20 : E[θE | t]=-6.000, Var[θE | t]=129.325
mu_scale=20, tau_scale=20 : E[θM | t]=16.936, Var[θM | t]=112.901

mu_scale=20, tau_scale=3 : E[θG | t]=3.228, Var[θG | t]=74.198
mu_scale=20, tau_scale=3 : E[θE | t]=-2.422, Var[θE | t]=87.779
mu_scale=20, tau_scale=3 : E[θM | t]=13.039, Var[θM | t]=83.766

mu_scale=3, tau_scale=20 : E[θG | t]=0.632, Var[θG | t]=78.673
mu_scale=3, tau_scale=20 : E[θE | t]=-5.378, Var[θE | t]=96.102
mu_scale=3, tau_scale=20 : E[θM | t]=13.063, Var[θM | t]=124.248

Multiple observations with missing data

What if the president instead had a large meeting with 7 professors and the departments were not equally represented? How would she update her beliefs given unequal observations? Let’s say that she invited 3 professors from the Department of Government, 3 from the Department of English, and 1 from the Department of Math. We can explore how modulating the higher-level hyperpriors changes her posterior inference about each department.

\begin{align*} \mu ~\sim&~ \mathcal{N}(0, \dot\sigma_{\mu}) \\ \tau ~\sim&~ \text{HalfCauchy}(\dot\sigma_{\tau}) \\ \theta_d ~\sim&~ \mathcal{N}\left(\mu, \tau \right) \\ t_{(d,i)} ~\sim&~ \mathcal{N}(\theta_d, \dot\sigma{=}15) \\ d ~\in&~ D,~~~ \text{where}~~~ D = \{\text{G}, \text{E}, \text{M} \} \end{align*}
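
One implementation note before the code (my description of the likelihood defined below): missing observations are encoded as jnp.nan, and the likelihood multiplies together only the observed entries by summing log-densities with jnp.nansum, which skips NaN terms. A tiny self-contained check of that idea:

import jax.numpy as jnp
from jax.scipy.stats.norm import logpdf as normlogpdf

row = jnp.array([30.0, jnp.nan, jnp.nan])  # one observed arrival time, two missing
# nansum ignores the NaN terms, so the row's joint log-likelihood
# equals the log-density of the single observed entry
print(jnp.nansum(normlogpdf(row, 0.0, 15.0)))
print(normlogpdf(30.0, 0.0, 15.0))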

%reset -f
import jax
import jax.numpy as jnp
from memo import memo
from enum import IntEnum
from jax.scipy.stats.norm import pdf as normpdf
from jax.scipy.stats.norm import logpdf as normlogpdf
from jax.scipy.stats.cauchy import pdf as cauchypdf
from matplotlib import pyplot as plt

normpdfjit = jax.jit(normpdf)

class Department(IntEnum):
    GOVERNMENT = 0
    ENGLISH = 1
    MATH = 2

###NEW
t = jnp.array([
    [-10, 1, 11], 
    [-16, -15, -14], 
    [30, jnp.nan, jnp.nan],
])

sigma = 15

Mu = jnp.linspace(-25, 25, 100)
Tau = jnp.linspace(1, 30, 100)
Theta = jnp.linspace(-40, 40, 200)

@jax.jit
def half_cauchy(x, scale=1.0):
    return 2 * cauchypdf(x, 0, scale)

@jax.jit
def professor_arrival_likelihood(d, theta):
    ### joint likelihood of all observed arrival times t[d, :]
    ### from department d, for a given theta and sigma.
    ### Missing observations are NaN; summing log-densities with
    ### jnp.nansum skips them, so only observed entries contribute.
    return jnp.exp(jnp.nansum(normlogpdf(t[d], theta, sigma)))  ###NEW

@memo
def department_model[_mu: Mu, _tau: Tau](d):
    department: knows(_mu, _tau)
    department: chooses(theta in Theta, wpp=normpdfjit(theta, _mu, _tau))
    return E[ department[professor_arrival_likelihood(d, theta)] ]

@memo
def partial_pooling[_mu: Mu, _tau: Tau](mu_scale=5, tau_scale=5):
    president: knows(_mu, _tau)
    president: thinks[
        population: chooses(mu in Mu, wpp=normpdfjit(mu, 0, mu_scale)),
        population: chooses(tau in Tau, wpp=half_cauchy(tau, tau_scale)),
    ]
    president: observes_event(wpp=department_model[population.mu, population.tau]({Department.GOVERNMENT}))
    president: observes_event(wpp=department_model[population.mu, population.tau]({Department.ENGLISH}))
    president: observes_event(wpp=department_model[population.mu, population.tau]({Department.MATH}))
    return president[Pr[population.mu == _mu, population.tau == _tau]]

@memo
def department_model_theta[_theta: Theta](d, mu_scale, tau_scale):
    obs: thinks[
        department: chooses(mu in Mu, tau in Tau, wpp=partial_pooling[mu, tau](mu_scale, tau_scale)),
        department: chooses(theta in Theta, wpp=normpdfjit(theta, mu, tau))
    ]
    obs: observes_event(wpp=professor_arrival_likelihood(d, department.theta))
    obs: knows(_theta)
    return obs[Pr[department.theta == _theta]]
Plotting function
def plot_model(mu_scale=1, tau_scale=1, figsize=(10, 8)):
    posterior = partial_pooling(mu_scale=mu_scale, tau_scale=tau_scale)

    # Marginal over Tau (sum over Mu)
    posterior_tau = posterior.sum(axis=0)
    # Marginal over Mu (sum over Tau)
    posterior_mu = posterior.sum(axis=1)

    fig, axs = plt.subplots(3, 1, figsize=figsize)

    ax = axs[0]
    ax.axvline(0, color="black", linestyle="-")
    ax.plot(Mu, posterior_mu, label=r"$P(\mu \mid t)$")
    mu_expectation = jnp.dot(Mu, posterior_mu)
    ax.axvline(
        mu_expectation, 
        color='red', 
        linestyle='--', 
        label=r"$\operatorname{E}" + rf"[\mu \mid t]={mu_expectation:0.2f}$")
    _ = ax.set_title(r"Posterior of $\mu$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    ax = axs[1]
    ax.axvline(0, color="black", linestyle="-")
    ax.plot(Tau, posterior_tau, label=r"$P(\tau \mid t)$")
    tau_expectation = jnp.dot(Tau, posterior_tau)
    ax.axvline(
        tau_expectation, 
        color='red', 
        linestyle='--', 
        label=r"$\operatorname{E}" + rf"[\tau \mid t]={tau_expectation:0.2f}$")
    _ = ax.set_title(r"Posterior of $\tau$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    ax = axs[2]
    ax.axvline(0, color="black", linestyle="-")
    for d in Department:
        department_name = d.name
        department_abbrev = department_name[0]
        theta_posterior = department_model_theta(d, mu_scale, tau_scale)
        theta_expectation = jnp.dot(Theta, theta_posterior)
        ax.plot(
            Theta, 
            theta_posterior, 
            label=(
                rf"$P(\theta_{department_abbrev} \mid t),~ " 
                + r"\operatorname{E}" 
                + rf"[\theta_{department_abbrev} \mid t]={theta_expectation:0.2f}$"))
    _ = ax.set_title(r"Posterior of $\theta_d$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    _ = plt.suptitle(f"mu_scale = {mu_scale}, tau_scale = {tau_scale}", y=1)
    plt.tight_layout()
    plt.show()
param_list = [
    dict(mu_scale=1, tau_scale=1),
    # dict(mu_scale=5, tau_scale=5),
    # dict(mu_scale=10, tau_scale=10),
    dict(mu_scale=20, tau_scale=20),
    # dict(mu_scale=10, tau_scale=1),
    dict(mu_scale=20, tau_scale=1),
    # dict(mu_scale=1, tau_scale=10),
    dict(mu_scale=1, tau_scale=20),
]
for i_params_, params_ in enumerate(param_list):
    plot_model(**params_, figsize=(6, 5))

When tau_scale is small (e.g. { \dot\sigma_{\tau}{=}1 }), the distribution { \theta_d ~\sim~ \mathcal{N}(\mu, \tau) } will be narrow, which will pressure the individual { \theta_d } to be similar to each other. When tau_scale is large (e.g. { \dot\sigma_{\tau}{=}20 }), the { \theta_d } can diverge more freely, resulting in values quite different from each other. Think about how this will interact with the number of observations and the values of those observations.

When mu_scale is small (e.g. { \dot\sigma_{\mu}{=}1 }), the location of the distribution over \theta_d will be centered near zero. When mu_scale is large (e.g. { \dot\sigma_{\mu}{=}20 }), the center of { \mathcal{N}(\mu, \tau) } can move to new locations.

In this way, a large { \dot\sigma_{\mu} } and small { \dot\sigma_{\tau} } would allow { \mathcal{N}(\mu, \tau) } to easily move to a new location that accommodates the broader tendencies of departments and professors, while constraining the dispersion between departments. Notice how increasing mu_scale allows the posterior { P(\mu \mid t) } to move away from zero towards the mean of the combined data, whereas decreasing mu_scale pulls the expectation { \operatorname{E}[\mu \mid t] } towards zero. Also notice how this effect is stronger when tau_scale is small.

Pay close attention to the interaction between mu_scale, tau_scale, and the observations for each department (including the expectation and variance of the observations of each department).
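
One way to keep those observation-level quantities in view (a small sketch, assuming the cell above has already been run so that t and Department are defined):

for d in Department:
    n_obs = int(jnp.sum(~jnp.isnan(t[d])))  # count of non-missing observations
    print(
        f"{d.name}: n={n_obs}, "
        f"mean={float(jnp.nanmean(t[d])):0.2f}, "
        f"var={float(jnp.nanvar(t[d])):0.2f}"
    )

Math, for example, contributes only a single observation, while Government and English each contribute three.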

Multiple observations with missing data and learned variance

So far the observation noise has been fixed at { \dot\sigma{=}15 }. What if the president also wants to learn how variable each department’s arrival times are? We can give each department its own standard deviation \sigma_d with a HalfCauchy prior and infer it jointly with \theta_d.

Partial Variance Pooling with learned \sigma_d

\begin{align*} \mu ~\sim&~ \mathcal{N}(0, \dot\sigma_{\mu}) \\ \tau ~\sim&~ \text{HalfCauchy}(\dot\sigma_{\tau}) \\ \theta_d ~\sim&~ \mathcal{N}(\mu, \tau) \\ \sigma_d ~\sim&~ \text{HalfCauchy}(5) \\ t_{(d,i)} ~\sim&~ \mathcal{N}(\theta_d, \sigma_d) \\ d ~\in&~ D,~~~ \text{where}~~~ D = \{\text{G}, \text{E}, \text{M} \} \end{align*}

%reset -f
import jax
import jax.numpy as jnp
from memo import memo
from enum import IntEnum
from jax.scipy.stats.norm import pdf as normpdf
from jax.scipy.stats.norm import logpdf as normlogpdf
from jax.scipy.stats.cauchy import pdf as cauchypdf
from matplotlib import pyplot as plt

normpdfjit = jax.jit(normpdf)

class Department(IntEnum):
    GOVERNMENT = 0
    ENGLISH = 1
    MATH = 2

t = jnp.array([
    [-10, 1, 11], 
    [-16, -15, -14], 
    [30, jnp.nan, jnp.nan],
])

Mu = jnp.linspace(-25, 25, 100)
Tau = jnp.linspace(1, 30, 100)
Theta = jnp.linspace(-40, 40, 200)
Sigma = jnp.linspace(1, 30, 100)  ###NEW

@jax.jit
def half_cauchy(x, scale=1.0):
    return 2 * cauchypdf(x, 0, scale)

@jax.jit
def professor_arrival_likelihood(d, theta, sigma): ###NEW
    ### joint likelihood of all observed arrival times t[d, :]
    ### from department d, for a given theta and the department's
    ### own sigma; NaN (missing) entries are skipped by jnp.nansum.
    return jnp.exp(jnp.nansum(normlogpdf(t[d], theta, sigma)))

@memo(cache=True)
def department_model[_mu: Mu, _tau: Tau](d):
    department: knows(_mu, _tau)
    department: chooses(theta in Theta, wpp=normpdfjit(theta, _mu, _tau))
    department: chooses(sigma in Sigma, wpp=half_cauchy(sigma, 5))  ###NEW
    return E[ department[professor_arrival_likelihood(d, theta, sigma)] ]  ###NEW

@memo(cache=True)
def partial_pooling[_mu: Mu, _tau: Tau](mu_scale=5, tau_scale=5):
    president: knows(_mu, _tau)
    president: thinks[
        population: chooses(mu in Mu, wpp=normpdfjit(mu, 0, mu_scale)),
        population: chooses(tau in Tau, wpp=half_cauchy(tau, tau_scale)),
    ]
    president: observes_event(wpp=department_model[population.mu, population.tau]({Department.GOVERNMENT}))
    president: observes_event(wpp=department_model[population.mu, population.tau]({Department.ENGLISH}))
    president: observes_event(wpp=department_model[population.mu, population.tau]({Department.MATH}))
    return president[Pr[population.mu == _mu, population.tau == _tau]]

@memo(cache=True)
def department_model_theta[_theta: Theta](d, mu_scale, tau_scale):
    obs: thinks[
        department: chooses(mu in Mu, tau in Tau, wpp=partial_pooling[mu, tau](mu_scale, tau_scale)),
        department: chooses(theta in Theta, wpp=normpdfjit(theta, mu, tau)),
        department: chooses(sigma in Sigma, wpp=half_cauchy(sigma, 5)),  ###NEW
    ]
    obs: observes_event(wpp=professor_arrival_likelihood(d, department.theta, department.sigma))  ###NEW
    obs: knows(_theta)
    return obs[Pr[department.theta == _theta]]

@memo(cache=True)
def department_model_sigma[_sigma: Sigma](d, mu_scale, tau_scale):  ###NEW
    obs: thinks[
        department: chooses(mu in Mu, tau in Tau, wpp=partial_pooling[mu, tau](mu_scale, tau_scale)),
        department: chooses(theta in Theta, wpp=normpdfjit(theta, mu, tau)),
        department: chooses(sigma in Sigma, wpp=half_cauchy(sigma, 5)),
    ]
    obs: observes_event(wpp=professor_arrival_likelihood(d, department.theta, department.sigma))
    obs: knows(_sigma)
    return obs[Pr[department.sigma == _sigma]]
Caching function
# Cache for precomputed results
cache = {}

def compute_distributions(mu_scale, tau_scale, verbose=False):
    """Retrieve from cache or compute the JAX-based distributions."""
    key = (mu_scale, tau_scale)
    if key in cache:
        return cache[key]  # Use cached results

    if verbose:
        print(f"{mu_scale=}, {tau_scale=}")

    # Cache results
    cache[key] = (
        partial_pooling(mu_scale=mu_scale, tau_scale=tau_scale).sum(axis=1), 
        partial_pooling(mu_scale=mu_scale, tau_scale=tau_scale).sum(axis=0), 
        {d: department_model_sigma(d, mu_scale, tau_scale) for d in Department}, 
        {d: department_model_theta(d, mu_scale, tau_scale) for d in Department},
    )
    return cache[key]
Plotting function
def plot_model(mu_scale=1, tau_scale=1, figsize=(10, 8)):
    posterior_mu, posterior_tau, sigma_posteriors, theta_posteriors = compute_distributions(mu_scale, tau_scale)

    fig, axs = plt.subplots(4, 1, figsize=figsize)

    ax = axs[0]
    ax.axvline(0, color="black", linestyle="-")
    ax.plot(Mu, posterior_mu, label=r"$P(\mu \mid t)$")
    mu_expectation = jnp.dot(Mu, posterior_mu)
    ax.axvline(
        mu_expectation, 
        color='red', 
        linestyle='--', 
        label=r"$\operatorname{E}" + rf"[\mu \mid t]={mu_expectation:0.2f}$")
    _ = ax.set_title(r"Posterior of $\mu$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    ax = axs[1]
    ax.axvline(0, color="black", linestyle="-")
    ax.plot(Tau, posterior_tau, label=r"$P(\tau \mid t)$")
    tau_expectation = jnp.dot(Tau, posterior_tau)
    ax.axvline(
        tau_expectation, 
        color='red', 
        linestyle='--', 
        label=r"$\operatorname{E}" + rf"[\tau \mid t]={tau_expectation:0.2f}$")
    _ = ax.set_title(r"Posterior of $\tau$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    ax = axs[2]
    ax.axvline(0, color="black", linestyle="-")
    for d in Department:
        department_name = d.name
        department_abbrev = department_name[0]
        sigma_posterior = sigma_posteriors[d]
        sigma_expectation = jnp.dot(Sigma, sigma_posterior)
        ax.plot(
            Sigma, 
            sigma_posterior, 
            label=(
                rf"$P(\sigma_{department_abbrev} \mid t),~ " 
                + r"\operatorname{E}" 
                + rf"[\sigma_{department_abbrev} \mid t]={sigma_expectation:0.2f}$"))
    _ = ax.set_xlim(-1, 20)
    _ = ax.set_title(r"Posterior of $\sigma_d$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    ax = axs[3]
    ax.axvline(0, color="black", linestyle="-")
    for d in Department:
        department_name = d.name
        department_abbrev = department_name[0]
        theta_posterior = theta_posteriors[d]
        theta_expectation = jnp.dot(Theta, theta_posterior)
        ax.plot(
            Theta, 
            theta_posterior, 
            label=(
                rf"$P(\theta_{department_abbrev} \mid t),~ " 
                + r"\operatorname{E}" 
                + rf"[\theta_{department_abbrev} \mid t]={theta_expectation:0.2f}$"))
    _ = ax.set_xlim(-20, 40)
    _ = ax.set_title(r"Posterior of $\theta_d$")
    _ = ax.legend(bbox_to_anchor=(0.9, 0.5), loc='center left')

    _ = plt.suptitle(f"mu_scale = {mu_scale}, tau_scale = {tau_scale}", y=1)
    plt.tight_layout()
    plt.show()
param_list = [
    dict(mu_scale=1, tau_scale=1),
    # dict(mu_scale=5, tau_scale=5),
    dict(mu_scale=10, tau_scale=10),
    # dict(mu_scale=20, tau_scale=20),
    dict(mu_scale=10, tau_scale=1),
    # dict(mu_scale=20, tau_scale=1),
    dict(mu_scale=1, tau_scale=10),
    # dict(mu_scale=1, tau_scale=20),
]
plt.close("all")
for i_params_, params_ in enumerate(param_list):
    plot_model(**params_, figsize=(10, 8))

Further reading

%reset -f
import sys
import platform
import importlib.metadata

print("Python:", sys.version)
print("Platform:", platform.system(), platform.release())
print("Processor:", platform.processor())
print("Machine:", platform.machine())

print("\nPackages:")
for name, version in sorted(
    ((dist.metadata["Name"], dist.version) for dist in importlib.metadata.distributions()),
    key=lambda x: x[0].lower()  # Sort case-insensitively
):
    print(f"{name}=={version}")
Python: 3.13.2 (main, Feb  5 2025, 18:58:04) [Clang 19.1.6 ]
Platform: Darwin 23.6.0
Processor: arm
Machine: arm64

Packages:
annotated-types==0.7.0
anyio==4.8.0
appnope==0.1.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==3.0.0
async-lru==2.0.4
attrs==25.1.0
babel==2.17.0
beautifulsoup4==4.13.3
bleach==6.2.0
certifi==2025.1.31
cffi==1.17.1
cfgv==3.4.0
charset-normalizer==3.4.1
click==8.1.8
comm==0.2.2
contourpy==1.3.1
cycler==0.12.1
debugpy==1.8.13
decorator==5.2.1
defusedxml==0.7.1
distlib==0.3.9
executing==2.2.0
fastjsonschema==2.21.1
filelock==3.17.0
fonttools==4.56.0
fqdn==1.5.1
h11==0.14.0
httpcore==1.0.7
httpx==0.28.1
identify==2.6.8
idna==3.10
importlib_metadata==8.6.1
ipykernel==6.29.5
ipython==9.0.1
ipython_pygments_lexers==1.1.1
ipywidgets==8.1.5
isoduration==20.11.0
jax==0.5.2
jaxlib==0.5.1
jedi==0.19.2
Jinja2==3.1.6
joblib==1.4.2
json5==0.10.0
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
jupyter-cache==1.0.1
jupyter-events==0.12.0
jupyter-lsp==2.2.5
jupyter_client==8.6.3
jupyter_core==5.7.2
jupyter_server==2.15.0
jupyter_server_terminals==0.5.3
jupyterlab==4.3.5
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.3
jupyterlab_widgets==3.0.13
kiwisolver==1.4.8
MarkupSafe==3.0.2
matplotlib==3.10.1
matplotlib-inline==0.1.7
memo-lang==1.1.0
mistune==3.1.2
ml_dtypes==0.5.1
nbclient==0.10.2
nbconvert==7.16.6
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.4.2
nodeenv==1.9.1
notebook_shim==0.2.4
numpy==2.2.3
opt_einsum==3.4.0
optype==0.9.1
overrides==7.7.0
packaging==24.2
pandas==2.2.3
pandas-stubs==2.2.3.241126
pandocfilters==1.5.1
parso==0.8.4
pexpect==4.9.0
pillow==11.1.0
platformdirs==4.3.6
plotly==5.24.1
pre_commit==4.1.0
prometheus_client==0.21.1
prompt_toolkit==3.0.50
psutil==7.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
pycparser==2.22
pydantic==2.10.6
pydantic_core==2.27.2
Pygments==2.19.1
pygraphviz==1.14
pyparsing==3.2.1
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-json-logger==3.3.0
pytz==2025.1
PyYAML==6.0.2
pyzmq==26.2.1
referencing==0.36.2
requests==2.32.3
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.23.1
ruff==0.9.10
scikit-learn==1.6.1
scipy==1.15.2
scipy-stubs==1.15.2.0
seaborn==0.13.2
Send2Trash==1.8.3
setuptools==75.8.2
six==1.17.0
sniffio==1.3.1
soupsieve==2.6
SQLAlchemy==2.0.38
stack-data==0.6.3
tabulate==0.9.0
tenacity==9.0.0
terminado==0.18.1
threadpoolctl==3.5.0
tinycss2==1.4.0
toml==0.10.2
tornado==6.4.2
tqdm==4.67.1
traitlets==5.14.3
types-python-dateutil==2.9.0.20241206
types-pytz==2025.1.0.20250204
typing_extensions==4.12.2
tzdata==2025.1
uri-template==1.3.0
urllib3==2.3.0
virtualenv==20.29.3
wcwidth==0.2.13
webcolors==24.11.1
webencodings==0.5.1
websocket-client==1.8.0
widgetsnbextension==4.0.13
xarray==2025.1.2
zipp==3.21.0

References

Gelman, Andrew. (2014). Bayesian data analysis (3rd ed.). CRC Press.
Kruschke, John K. (2015). Doing Bayesian data analysis: A tutorial with R, JAGS, and Stan (2nd ed.). Academic Press.
McElreath, Richard. (2016). Statistical rethinking: A Bayesian course with examples in R and Stan. CRC Press/Taylor & Francis Group.