import jax
import jax.numpy as jnp
from memo import memo
from memo import domain as product
from enum import IntEnum
from matplotlib import pyplot as plt
Conditional dependence
Patterns of inference as evidence changes
Causal motifs
The Fork
• A and B are associated: {A \not\perp B}
• Share a common cause Z
• Once stratified by Z, no association: {A \perp B \mid Z}
The Pipe
• A and B are associated: {A \not\perp B}
• Influence of A on B transmitted through Z
• Once stratified by Z, no association: {A \perp B \mid Z}
The Collider
• A and B are not associated (no shared causes): {A \perp B}
• A and B both influence Z
• Once stratified by Z, A and B are associated: {A \not\perp B \mid Z}
The Descendant
• A and B are causally associated through Z: {A \not\perp B}
• C holds information about Z
• Once stratified by C, A and B are less associated: approximately {A \perp B \mid C} (exactly so only if C carries enough information to determine Z)
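Each motif corresponds to a factorization of the joint distribution. Written in the standard way (with the descendant motif drawn as a pipe from A to B whose mediator Z also has a child C):
\begin{align*} \text{Fork:} ~~& P(A, B, Z) = P(Z) \; P(A \mid Z) \; P(B \mid Z) \\ \text{Pipe:} ~~& P(A, B, Z) = P(A) \; P(Z \mid A) \; P(B \mid Z) \\ \text{Collider:} ~~& P(A, B, Z) = P(A) \; P(B) \; P(Z \mid A, B) \\ \text{Descendant:} ~~& P(A, B, Z, C) = P(A) \; P(Z \mid A) \; P(B \mid Z) \; P(C \mid Z) \end{align*}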
Conditional independence
Recall the definition of (in)dependence. We define conditional (in)dependence in effectively the same way, just conditioning everything on a new random variable, Z. So whereas (marginal) independence is defined as { A \perp B \iff P(A, \, B) = P(A) \; P(B) }, conditional independence is defined as:
A \perp B \mid Z \iff P(A, \, B \mid Z) = P(A \mid Z) \; P(B \mid Z) \\ \text{where}~~~ P(Z) > 0 \\ \forall~ (a, b, z) \in \{ A \times B \times Z \}
And in the same way as we saw for marginal independence,
A \perp B \mid Z ~~\iff~~ P(A \mid B, \, Z) = P(A \mid Z)
Exercise: Show that, if {A \perp B \mid Z}, then {P(B \mid A, \, Z) = P(B \mid Z)}. Hint: Follow the chain rule.
In general (regardless of dependence),
P(A, \, B \mid Z) = P(A \mid Z) \; P(B \mid A, \, Z) \\ \forall~ (a, b, z) \in \{ A \times B \times Z \} ~~~\text{where}~~~ P(A, \, Z) > 0
If {A \perp B \mid Z}, then {P(A, \, B \mid Z) = P(A \mid Z) \; P(B \mid Z)}.
In which case,
P(A \mid Z) \; P(B \mid Z) = P(A \mid Z) \; P(B \mid A, \, Z)
and {P(B \mid Z) = P(B \mid A, \, Z)}.
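To make these definitions concrete, here is a minimal numeric sketch (plain jax.numpy rather than memo; it uses the same conditional probabilities as the fork model in the next section). It builds a fork-structured joint {P(A, B, Z)}, confirms that { P(A, \, B \mid Z) = P(A \mid Z) \; P(B \mid Z) } for every (a, b, z), and confirms that the marginal factorization fails:
import jax.numpy as jnp

p_z = jnp.array([0.5, 0.5])                 # P(Z)
p_a_given_z = jnp.array([[0.1, 0.9],        # P(A | Z=FALSE)
                         [0.5, 0.5]])       # P(A | Z=TRUE)
p_b_given_z = jnp.array([[0.6, 0.4],        # P(B | Z=FALSE)
                         [0.9, 0.1]])       # P(B | Z=TRUE)

# Joint P(Z, A, B) under the fork factorization P(Z) P(A|Z) P(B|Z)
joint = p_z[:, None, None] * p_a_given_z[:, :, None] * p_b_given_z[:, None, :]

# Condition on Z: P(A, B | Z) and its marginals P(A | Z), P(B | Z)
p_ab_given_z = joint / joint.sum(axis=(1, 2), keepdims=True)
p_a_z = p_ab_given_z.sum(axis=2, keepdims=True)
p_b_z = p_ab_given_z.sum(axis=1, keepdims=True)

# A ⊥ B | Z  <=>  P(A, B | Z) = P(A | Z) P(B | Z) for every (a, b, z)
print(jnp.allclose(p_ab_given_z, p_a_z * p_b_z))   # True

# ...but marginally A and B are dependent: P(A, B) != P(A) P(B)
p_ab = joint.sum(axis=0)
print(jnp.allclose(p_ab, p_ab.sum(axis=1, keepdims=True) * p_ab.sum(axis=0, keepdims=True)))  # False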
From A Priori Dependence to Conditional Dependence
The relationships between causal structure and statistical dependence become particularly interesting and subtle when we look at the effects of additional observations or assumptions. Events that are statistically dependent a priori may become independent when we condition on some observation; this is called screening off. Also, events that are statistically independent a priori may become dependent when we condition on observations; this is known as explaining away. The dynamics of screening off and explaining away are extremely important for understanding patterns of inference—reasoning and learning—in probabilistic models.
Screening off
Screening off refers to a pattern of statistical inference that is quite common in both scientific and intuitive reasoning. If the statistical dependence between two events A and B is only indirect, mediated strictly by one or more other events Z, then conditioning on (observing) Z should render A and B statistically independent. This can occur if A and B are connected by one or more causal chains, and all such chains run through the set of events Z, or if Z comprises all of the common causes of A and B.
Fork
For instance, let’s look at a common cause example (a “fork”). Here, A and B are associated ({A \not\perp B}). We can determine this by observing that { P(A \mid B{=}\mathtt{true}) \neq P(A \mid B{=}\mathtt{false}) }.
class A(IntEnum):
    FALSE = 0
    TRUE = 1

class B(IntEnum):
    FALSE = 0
    TRUE = 1

class Z(IntEnum):
    FALSE = 0
    TRUE = 1
@memo
def fork[_a: A, _b: B]():
    agent: knows(_a, _b)
    agent: thinks[
        friend: knows(_a, _b),
        friend: chooses(z in Z, wpp=1),
        friend: chooses(a in A, wpp=(
            1 if z == {Z.TRUE}
            else (0.9 if a == {A.TRUE} else 0.1))),
        friend: chooses(b in B, wpp=(
            (0.1 if b == {B.TRUE} else 0.9)
            if z == 1
            else (0.4 if b == {B.TRUE} else 0.6))),
    ]
    ### this observes statement is just here to
    ### make it easier to inspect the results
    agent: observes [friend.b] is _b
    return agent[Pr[friend.a == _a]]
res = fork(print_table=True, return_aux=True, return_xarray=True)
resx = res.aux.xarray
print("\n")
for a in A:
    for b in B:
        a_, b_ = a.name, b.name
        print(f"P(A={a_} | B={b_}) = {resx.loc[a_, b_].item():0.4f}")
    print("---")
+-------+-------+----------------------+
| _a: A | _b: B | fork |
+-------+-------+----------------------+
| FALSE | FALSE | 0.3399999737739563 |
| FALSE | TRUE | 0.18000000715255737 |
| TRUE | FALSE | 0.6599999666213989 |
| TRUE | TRUE | 0.8199999928474426 |
+-------+-------+----------------------+
P(A=FALSE | B=FALSE) = 0.3400
P(A=FALSE | B=TRUE) = 0.1800
---
P(A=TRUE | B=FALSE) = 0.6600
P(A=TRUE | B=TRUE) = 0.8200
---
Note: The DAG above shows the model {P(A, B, Z)}. However, the @memo model (fork) is conditioned on B — this is done purely because it’s easier to visually inspect whether { P(A \mid B) = P(A \mid \neg B) = P(A) } than whether { P(A, B) = P(A) \; P(B) }.
The conceptually important part is what happens when the model is conditioned on Z, so while I am conditioning the model on B, I’m not shading the B node in the DAG.
We could, of course, run the model that’s actually shown in the DAG by removing the agent: observes [friend.b] is _b line, and then check whether { P(A, B) = P(A) \; P(B) } (see the “Without extra conditioning” details below for a demonstration).
In the joint model, { P(A, B) \neq P(A) \; P(B) }.
@memo
def fork_joint[_a: A, _b: B]():
    agent: knows(_a, _b)
    agent: thinks[
        friend: knows(_a, _b),
        friend: chooses(z in Z, wpp=1),
        friend: chooses(a in A, wpp=(
            1 if z == {Z.TRUE}
            else (0.9 if a == {A.TRUE} else 0.1))),
        friend: chooses(b in B, wpp=(
            (0.1 if b == {B.TRUE} else 0.9)
            if z == 1
            else (0.4 if b == {B.TRUE} else 0.6))),
    ]
    ### without observes statement
    # agent: observes [friend.b] is _b
    return agent[Pr[friend.a == _a, friend.b == _b]]
res = fork_joint(print_table=True, return_aux=True, return_xarray=True)
resx = res.aux.xarray
print("\nCompare:")
for a in A:
    for b in B:
        a_, b_ = a.name, b.name
        print("\n")
        print(f" P(A={a_}, B={b_}) = {resx.loc[a_, b_].item():0.4f}")
        print(f"P(A={a_}) P(B={b_}) = {resx.loc[a_, :].sum().item() * resx.loc[:, b_].sum().item():0.4f}")
+-------+-------+----------------------+
| _a: A | _b: B | fork_joint |
+-------+-------+----------------------+
| FALSE | FALSE | 0.2549999952316284 |
| FALSE | TRUE | 0.04500000178813934 |
| TRUE | FALSE | 0.4950000047683716 |
| TRUE | TRUE | 0.20499999821186066 |
+-------+-------+----------------------+
Compare:
P(A=FALSE, B=FALSE) = 0.2550
P(A=FALSE) P(B=FALSE) = 0.2250
P(A=FALSE, B=TRUE) = 0.0450
P(A=FALSE) P(B=TRUE) = 0.0750
P(A=TRUE, B=FALSE) = 0.4950
P(A=TRUE) P(B=FALSE) = 0.5250
P(A=TRUE, B=TRUE) = 0.2050
P(A=TRUE) P(B=TRUE) = 0.1750
But once stratified by Z, { P(A, B \mid Z) } is always equal to { P(A \mid Z) \; P(B \mid Z) }.
@memo
def fork_joint__z[_z: Z, _a: A, _b: B]():
    agent: knows(_a, _b, _z)
    agent: thinks[
        friend: knows(_a, _b),
        friend: chooses(z in Z, wpp=1),
        friend: chooses(a in A, wpp=(
            1 if z == {Z.TRUE}
            else (0.9 if a == {A.TRUE} else 0.1))),
        friend: chooses(b in B, wpp=(
            (0.1 if b == {B.TRUE} else 0.9)
            if z == 1
            else (0.4 if b == {B.TRUE} else 0.6))),
    ]
    agent: observes [friend.z] is _z
    ### without observes statement
    # agent: observes [friend.b] is _b
    return agent[Pr[friend.a == _a, friend.b == _b]]
res = fork_joint__z(print_table=True, return_aux=True, return_xarray=True)
resx = res.aux.xarray

print("\nCompare:")
for z in Z:
    for a in A:
        for b in B:
            z_, a_, b_ = z.name, a.name, b.name
            print("\n")
            print(f"        P(A={a_}, B={b_} | Z={z_}) = {resx.loc[z_, a_, b_].item():0.4f}")
            print(f"P(A={a_} | Z={z_}) P(B={b_} | Z={z_}) = {resx.loc[z_, a_, :].sum().item() * resx.loc[z_, :, b_].sum().item():0.4f}")
+-------+-------+-------+----------------------+
| _z: Z | _a: A | _b: B | fork_joint__z |
+-------+-------+-------+----------------------+
| FALSE | FALSE | FALSE | 0.06000000238418579 |
| FALSE | FALSE | TRUE | 0.04000000283122063 |
| FALSE | TRUE | FALSE | 0.5400000214576721 |
| FALSE | TRUE | TRUE | 0.35999998450279236 |
| TRUE | FALSE | FALSE | 0.44999998807907104 |
| TRUE | FALSE | TRUE | 0.05000000074505806 |
| TRUE | TRUE | FALSE | 0.44999998807907104 |
| TRUE | TRUE | TRUE | 0.05000000074505806 |
+-------+-------+-------+----------------------+
Compare:
P(A=FALSE, B=FALSE | Z=FALSE) = 0.0600
P(A=FALSE | Z=FALSE) P(B=FALSE | Z=FALSE) = 0.0600
P(A=FALSE, B=TRUE | Z=FALSE) = 0.0400
P(A=FALSE | Z=FALSE) P(B=TRUE | Z=FALSE) = 0.0400
P(A=TRUE, B=FALSE | Z=FALSE) = 0.5400
P(A=TRUE | Z=FALSE) P(B=FALSE | Z=FALSE) = 0.5400
P(A=TRUE, B=TRUE | Z=FALSE) = 0.3600
P(A=TRUE | Z=FALSE) P(B=TRUE | Z=FALSE) = 0.3600
P(A=FALSE, B=FALSE | Z=TRUE) = 0.4500
P(A=FALSE | Z=TRUE) P(B=FALSE | Z=TRUE) = 0.4500
P(A=FALSE, B=TRUE | Z=TRUE) = 0.0500
P(A=FALSE | Z=TRUE) P(B=TRUE | Z=TRUE) = 0.0500
P(A=TRUE, B=FALSE | Z=TRUE) = 0.4500
P(A=TRUE | Z=TRUE) P(B=FALSE | Z=TRUE) = 0.4500
P(A=TRUE, B=TRUE | Z=TRUE) = 0.0500
P(A=TRUE | Z=TRUE) P(B=TRUE | Z=TRUE) = 0.0500
Now let’s assume that we already know the value of Z:
@memo
def fork__z[_z: Z, _a: A, _b: B]():
    agent: knows(_a, _b, _z)
    agent: thinks[
        friend: knows(_a, _b),
        friend: chooses(z in Z, wpp=1),
        friend: chooses(a in A, wpp=(
            1 if z == {Z.TRUE}
            else (0.9 if a == {A.TRUE} else 0.1))),
        friend: chooses(b in B, wpp=(
            (0.1 if b == {B.TRUE} else 0.9)
            if z == {Z.TRUE}
            else (0.4 if b == {B.TRUE} else 0.6))),
    ]
    agent: observes [friend.z] is _z
    ### this observes statement is just here to
    ### make it easier to inspect the results
    agent: observes [friend.b] is _b
    return agent[Pr[friend.a == _a]]

res = fork__z(print_table=True, return_aux=True, return_xarray=True)
+-------+-------+-------+----------------------+
| _z: Z | _a: A | _b: B | fork__z |
+-------+-------+-------+----------------------+
| FALSE | FALSE | FALSE | 0.10000000149011612 |
| FALSE | FALSE | TRUE | 0.10000001639127731 |
| FALSE | TRUE | FALSE | 0.8999999761581421 |
| FALSE | TRUE | TRUE | 0.9000000357627869 |
| TRUE | FALSE | FALSE | 0.5 |
| TRUE | FALSE | TRUE | 0.5 |
| TRUE | TRUE | FALSE | 0.5 |
| TRUE | TRUE | TRUE | 0.5 |
+-------+-------+-------+----------------------+
We see that {A \perp B \mid Z}. While A and B are marginally dependent, they are conditionally independent given knowledge of Z.
E.g.
print(f"if A and B are conditionally independent given Z, then...")
= res.aux.xarray
resx for z in Z:
for a in A:
print("\nThese should be the same:")
for b in B:
= z.name, a.name, b.name
z_, a_, b_ print(f" P(A={a_} | B={b_}, Z={z_}) = {resx.loc[z_, a_, b_].item():0.4f}")
if A and B are conditionally independent given Z, then...
These should be the same:
P(A=FALSE | B=FALSE, Z=FALSE) = 0.1000
P(A=FALSE | B=TRUE, Z=FALSE) = 0.1000
These should be the same:
P(A=TRUE | B=FALSE, Z=FALSE) = 0.9000
P(A=TRUE | B=TRUE, Z=FALSE) = 0.9000
These should be the same:
P(A=FALSE | B=FALSE, Z=TRUE) = 0.5000
P(A=FALSE | B=TRUE, Z=TRUE) = 0.5000
These should be the same:
P(A=TRUE | B=FALSE, Z=TRUE) = 0.5000
P(A=TRUE | B=TRUE, Z=TRUE) = 0.5000
Continuous Fork
Joint Probability Heatmap
### Import
def lobf(x, y):
    import numpy as np
    return (np.unique(x), np.poly1d(jnp.polyfit(x, y, 1))(np.unique(x)))
def jointprob_facetgrid(P, support_dim0=None, support_dim1=None, origin=True, corr=True, cmap="BuPu", dim0: dict | None = None, dim1: dict | None = None, **kwargs):
    import seaborn as sns
    import numpy as np
    import matplotlib.pyplot as plt

    if support_dim0 is None:
        support_dim0 = np.arange(P.shape[0])
    if support_dim1 is None:
        support_dim1 = np.arange(P.shape[1])
    if isinstance(dim0, dict) and "support" in dim0:
        support_dim0 = dim0["support"]
    if isinstance(dim1, dict) and "support" in dim1:
        support_dim1 = dim1["support"]

    label_dim0 = "dim 0"
    label_dim1 = "dim 1"
    lim_dim0 = None
    lim_dim1 = None

    label_dim0 = kwargs.get("label_dim0", label_dim0)
    label_dim1 = kwargs.get("label_dim1", label_dim1)

    lim_dim0 = kwargs.get("lim_dim0", lim_dim0)
    lim_dim1 = kwargs.get("lim_dim1", lim_dim1)

    title = kwargs.get("title", None)

    if isinstance(dim0, dict):
        label_dim0 = dim0.get("label", label_dim0)
        lim_dim0 = dim0.get("lim", lim_dim0)
    if isinstance(dim1, dict):
        label_dim1 = dim1.get("label", label_dim1)
        lim_dim1 = dim1.get("lim", lim_dim1)

    assert len(P.shape) == 2
    assert len(support_dim0.shape) == 1
    assert len(support_dim1.shape) == 1

    # Create the grid for the joint plot with shared axes
    g = sns.JointGrid(ratio=8, height=4)
    g.ax_marg_x.sharex(g.ax_joint)
    g.ax_marg_y.sharey(g.ax_joint)

    # Plot the main density with contours
    contour = g.ax_joint.contourf(support_dim0, support_dim1, P.T, levels=20, cmap=cmap, norm="linear")
    g.ax_joint.contour(support_dim0, support_dim1, P.T, levels=10, colors="white", alpha=0.3, linewidths=0.5)

    xlim = g.ax_joint.get_xlim()
    ylim = g.ax_joint.get_ylim()

    if origin:
        _ = g.ax_joint.axhline(0, color="white", linewidth=0.5, alpha=0.3)
        _ = g.ax_joint.axvline(0, color="white", linewidth=0.5, alpha=0.3)

    # Plot marginal distributions with aligned axes
    g.ax_marg_x.set_facecolor('#f0f0f0')
    g.ax_marg_y.set_facecolor('#f0f0f0')
    g.ax_marg_x.fill_between(support_dim0, P.sum(axis=1), alpha=0.3)
    g.ax_marg_y.fill_betweenx(support_dim1, P.sum(axis=0), alpha=0.3)

    if corr:
        # Calculate expected values
        A, B = np.meshgrid(support_dim0, support_dim1)
        E_a = np.sum(A * P.T) / np.sum(P)
        E_b = np.sum(B * P.T) / np.sum(P)

        # Calculate covariance
        cov_ab = np.sum((A - E_a) * (B - E_b) * P.T) / np.sum(P)
        var_a = np.sum((A - E_a)**2 * P.T) / np.sum(P)
        var_b = np.sum((B - E_b)**2 * P.T) / np.sum(P)

        # Calculate correlation
        correlation = cov_ab / np.sqrt(var_a * var_b)

        # Calculate line of best fit
        slope = cov_ab / var_a
        intercept = E_b - slope * E_a

        # Plot line of best fit
        line_x = np.array([support_dim0[0], support_dim0[-1]])
        line_y = slope * line_x + intercept
        g.ax_joint.plot(line_x, line_y, "w--", linewidth=1, label=f'r = {correlation:.2f}')
        # g.ax_joint.legend()
        g.ax_joint.set_xlim(xlim)
        g.ax_joint.set_ylim(ylim)

    # Mark the marginal means
    E_d0 = np.sum(support_dim0 * P.sum(axis=1))
    E_d1 = np.sum(support_dim1 * P.sum(axis=0))
    g.ax_marg_x.scatter(E_d0, 0.0, color="red", s=15, marker="v", label="E[d0]")
    g.ax_marg_y.scatter(0.0, E_d1, color="red", s=15, marker="<", label="E[d1]")

    # Remove tick labels from marginal plots
    g.ax_marg_x.tick_params(labelbottom=False, direction='in')
    g.ax_marg_y.tick_params(labelleft=False, direction='in')

    if label_dim0 is not None:
        _ = g.ax_joint.set_xlabel(label_dim0)
    if label_dim1 is not None:
        _ = g.ax_joint.set_ylabel(label_dim1)

    if lim_dim0 is not None:
        _ = g.ax_joint.set_xlim(lim_dim0)
    if lim_dim1 is not None:
        _ = g.ax_joint.set_ylim(lim_dim1)

    if title:
        _ = g.figure.suptitle(title, y=1.03)

    return g
from jax.scipy.stats.norm import pdf as normpdf
normpdfjit = jax.jit(normpdf)

W = jnp.linspace(-3, 3, 11)
T = jnp.linspace(-3, 3, 11)
R = jnp.arange(2)

@memo
def viz_fork_joint[_w: W, _t: T]():
    agent: knows(_w, _t)
    agent: thinks[
        friend: chooses(r in R, wpp=1),
        friend: chooses(w in W, wpp=normpdfjit(w, 2*r-1, 0.5)),
        friend: chooses(t in T, wpp=normpdfjit(t, 2*r-1, 0.5)),
    ]
    return agent[Pr[friend.w == _w, friend.t == _t]]

fig = jointprob_facetgrid(viz_fork_joint(), W, T,
                          label_dim0="W", label_dim1="T", title=r"$P(W, T)$", cmap="inferno")
But when we stratify by R, {W \perp T \mid R}.
from jax.scipy.stats.norm import pdf as normpdf
normpdfjit = jax.jit(normpdf)

W = jnp.linspace(-3, 3, 11)
T = jnp.linspace(-3, 3, 11)
R = jnp.arange(2)

@memo
def viz_fork_stratified[_w: W, _t: T, _r: R]():
    agent: knows(_w, _t, _r)
    agent: thinks[
        friend: chooses(r in R, wpp=1),
        friend: chooses(w in W, wpp=normpdfjit(w, 2*r-1, 0.5)),
        friend: chooses(t in T, wpp=normpdfjit(t, 2*r-1, 0.5)),
    ]
    agent: observes [friend.r] is _r
    return agent[Pr[friend.w == _w, friend.t == _t]]

res = viz_fork_stratified()

fig = jointprob_facetgrid(
    res[:, :, 0],
    W,
    T,
    label_dim0="W",
    label_dim1="T",
    title=r"$P(W, T \mid R{=}0)$",
    cmap="inferno")
fig = jointprob_facetgrid(
    res[:, :, 1],
    W,
    T,
    label_dim0="W",
    label_dim1="T",
    title=r"$P(W, T \mid R{=}1)$",
    cmap="inferno")
Pipe
Screening off is a purely statistical phenomenon. For example, consider the causal chain model, where A directly causes Z, which in turn directly causes B (Z is a “mediator”).
@memo
def pipe[_a: A, _b: B]():
    agent: knows(_a, _b)
    agent: thinks[
        friend: knows(_a, _b),
        friend: chooses(a in A, wpp=1),
        friend: chooses(z in Z, wpp=(
            1 if a == {A.TRUE}
            else (0.9 if z == {Z.TRUE} else 0.1))),
        friend: chooses(b in B, wpp=(
            (0.1 if b == {B.TRUE} else 0.9)
            if z == {Z.TRUE}
            else (0.4 if b == {B.TRUE} else 0.6))),
    ]
    ### this observes statement is just here to
    ### make it easier to inspect the results
    agent: observes [friend.b] is _b
    return agent[Pr[friend.a == _a]]

_ = pipe(print_table=True)
+-------+-------+----------------------+
| _a: A | _b: B | pipe |
+-------+-------+----------------------+
| FALSE | FALSE | 0.5370370149612427 |
| FALSE | TRUE | 0.34210526943206787 |
| TRUE | FALSE | 0.4629629850387573 |
| TRUE | TRUE | 0.6578947305679321 |
+-------+-------+----------------------+
We can observe that A and B are associated.
But observing Z, the event that mediates the indirect causal relation between A and B, renders A and B independent. A and B are still causally dependent in our model; it is just our beliefs about the states of A and B that become conditionally independent.
@memo
def pipe[_z: Z, _a: A, _b: B]():
    agent: knows(_a, _b, _z)
    agent: thinks[
        friend: knows(_a, _b),
        friend: chooses(a in A, wpp=1),
        friend: chooses(z in Z, wpp=(
            1 if a == {A.TRUE}
            else (0.9 if z == {Z.TRUE} else 0.1))),
        friend: chooses(b in B, wpp=(
            (0.1 if b == {B.TRUE} else 0.9)
            if z == {Z.TRUE}
            else (0.4 if b == {B.TRUE} else 0.6))),
    ]
    agent: observes [friend.z] is _z
    ### this observes statement is just here to
    ### make it easier to inspect the results
    agent: observes [friend.b] is _b
    return agent[Pr[friend.a == _a]]

_ = pipe(print_table=True)
+-------+-------+-------+---------------------+
| _z: Z | _a: A | _b: B | pipe |
+-------+-------+-------+---------------------+
| FALSE | FALSE | FALSE | 0.1666666567325592 |
| FALSE | FALSE | TRUE | 0.1666666865348816 |
| FALSE | TRUE | FALSE | 0.8333333134651184 |
| FALSE | TRUE | TRUE | 0.8333333134651184 |
| TRUE | FALSE | FALSE | 0.6428571343421936 |
| TRUE | FALSE | TRUE | 0.6428571343421936 |
| TRUE | TRUE | FALSE | 0.3571428656578064 |
| TRUE | TRUE | TRUE | 0.3571428656578064 |
+-------+-------+-------+---------------------+
Explaining away
“Explaining away” (Pearl, 2014) refers to a complementary pattern of statistical inference which is somewhat more subtle than screening off. If two events A and B are statistically (and hence causally) independent, but they are both causes of one or more other events Z, then conditioning on (observing) Z can render A and B statistically dependent. Here is an example where A and B have a common effect (they collide on Z):
@memo
def collider[_z: Z, _a: A, _b: B]():
    agent: knows(_a, _b, _z)
    agent: thinks[
        friend: knows(_a, _b),
        friend: chooses(a in A, wpp=1),
        friend: chooses(b in B, wpp=1),
        friend: chooses(z in Z, wpp=(
            (0.1 if z == {Z.TRUE} else 0.9)
            if (a == {A.TRUE} or b == {B.TRUE})
            else (0.2 if z == {Z.TRUE} else 0.8))),
    ]
    agent: observes [friend.z] is _z
    ### this observes statement is just here to
    ### make it easier to inspect the results
    agent: observes [friend.b] is _b
    return agent[Pr[friend.a == _a]]

_ = collider(print_table=True)
+-------+-------+-------+---------------------+
| _z: Z | _a: A | _b: B | collider |
+-------+-------+-------+---------------------+
| FALSE | FALSE | FALSE | 0.4705882668495178 |
| FALSE | FALSE | TRUE | 0.5 |
| FALSE | TRUE | FALSE | 0.529411792755127 |
| FALSE | TRUE | TRUE | 0.5 |
| TRUE | FALSE | FALSE | 0.6666666269302368 |
| TRUE | FALSE | TRUE | 0.5 |
| TRUE | TRUE | FALSE | 0.3333333134651184 |
| TRUE | TRUE | TRUE | 0.5 |
+-------+-------+-------+---------------------+
As with screening off, we only induce statistical dependence from learning about Z, not causal dependence: when we observe Z, A and B remain causally independent in our model of the world; it is our beliefs about A and B that become statistically dependent.
The most typical pattern of explaining away we see in causal reasoning is a kind of anti-correlation: the probabilities of two possible causes for the same effect increase when the effect is observed, but they are conditionally anti-correlated, so that observing additional evidence in favor of one cause should lower our degree of belief in the other cause. (This pattern is where the term explaining away comes from.) However, the coupling induced by conditioning on common effects depends on the nature of the interaction between the causes; it is not always an anti-correlation. Explaining away takes the form of an anti-correlation when the causes interact in a roughly disjunctive or additive form: the effect tends to happen if any cause happens; or the effect happens if the sum of some continuous influences exceeds a threshold. The following simple mathematical examples show this and other patterns.
The model below defines two independent variables X and Y, both of which are used to define the value of our data. Suppose we condition on observing the sum of two integers drawn uniformly from 0 to 9:
from jax.scipy.stats.norm import pdf as normpdf
normpdfjit = jax.jit(normpdf)

X = jnp.arange(10)
Y = jnp.arange(10)

@memo
def f[_x: X, _y: Y]():
    agent: knows(_x, _y)
    agent: thinks[
        restaurant: knows(_x, _y),
        selection: chooses(x in X, wpp=1),
        selection: chooses(y in Y, wpp=1),
    ]
    agent: observes_that[ selection.x + selection.y == 9 ]
    return agent[Pr[selection.x == _x, selection.y == _y]]
g = jointprob_facetgrid(
    f(),
    X,
    Y,
    label_dim0="X",
    label_dim1="Y",
    title=r"$P(X, Y \mid X+Y=9)$",
    cmap="inferno")
This gives perfect anti-correlation in conditional inferences for X and Y. But suppose we instead condition on observing that X and Y are equal:
@memo
def ff[_x: X, _y: Y]():
    agent: knows(_x, _y)
    agent: thinks[
        restaurant: knows(_x, _y),
        selection: chooses(x in X, wpp=1),
        selection: chooses(y in Y, wpp=1),
    ]
    agent: observes_that[ selection.x == selection.y ]
    return agent[Pr[selection.x == _x, selection.y == _y]]
g = jointprob_facetgrid(
    ff(),
    X,
    Y,
    label_dim0="X",
    label_dim1="Y",
    title=r"$P(X, Y \mid X=Y)$",
    cmap="inferno")
Now, of course, X and Y go from being independent a priori to being perfectly correlated in the conditional distribution. Try out these other conditions, by substituting them into the observes_that statement of the template below, to see other possible patterns of conditional dependence for a priori independent functions:
selection.x - selection.y < 2
selection.x + selection.y >= 9 and selection.x + selection.y <= 11
abs(selection.x - selection.y) == 3
(selection.x - selection.y) % 10 == 3
selection.x % 2 == selection.y % 2
selection.x % 5 == selection.y % 5
selection.x % 2 == selection.y % 3
@memo
def fff[_x: X, _y: Y]():
    agent: knows(_x, _y)
    agent: thinks[
        restaurant: knows(_x, _y),
        selection: chooses(x in X, wpp=1),
        selection: chooses(y in Y, wpp=1),
    ]
    ### swap in one of the conditions listed above
    agent: observes_that[ selection.x == selection.y ]
    return agent[Pr[selection.x == _x, selection.y == _y]]
g = jointprob_facetgrid(
    fff(),
    X,
    Y,
    label_dim0="X",
    label_dim1="Y",
    title=r"Joint Posterior",
    cmap="inferno")
Collider bias - selection and survivorship bias
Restaurants
This model looks at the effect of conditioning on a collider.
It is not uncommon that restaurants with amazing food require some effort to get to, while convenient locations are filled with mediocre restaurants that somehow stay in business.
A restaurant’s popularity depends both on its food and on its location. Restaurants with exceptional food can thrive in out-of-the-way locations because people will make the extra effort to visit them. Restaurants in prime locations (think busy downtowns and tourist hotspots) can get away with serving subpar food because they get plenty of foot traffic or it’s hard for people to go elsewhere (think airports, music festivals, and rural college campuses).
The model represents this by measuring both food quality and location on a scale where 0 is “average,” positive numbers are “better than average,” and negative numbers are “worse than average.” Popularity is modeled as the combination of food and location. The popularity_value function sums the food and location values for a given restaurant. A restaurant is more likely to be popular if the food is good and if it’s easy to get to. The model also represents the idea that popularity isn’t perfectly determined by food and location alone: there are other factors that influence a restaurant’s popularity (things like ambiance, live music, and staff). These other factors are modeled as Gaussian noise centered on \text{food} + \text{location}:
\begin{align*} \text{food} ~\sim& ~~\mathcal{N}(\mu{=}0, \sigma_{\text{food}}) \\ \text{location} ~\sim& ~~\mathcal{N}(\mu{=}0, \sigma_{\text{location}}) \\ \text{popularity} ~\sim& ~~\mathcal{N}(\mu{=}\text{food}+\text{location}, \sigma_{\text{popularity}}) \end{align*}
from jax.scipy.stats.norm import pdf as normpdf
normpdfjit = jax.jit(normpdf)

Food = jnp.linspace(-2, 2, 11)
Location = jnp.linspace(-2, 2, 11)
Popularity = product(
    food=len(Food),
    loc=len(Location),
)

@jax.jit
def popularity_value(p):
    y_food = Popularity.food(p)
    y_loc = Popularity.loc(p)
    return y_food + y_loc

@jax.jit
def popularity_pdf(p, f, l):
    return normpdf(popularity_value(p), f + l, 0.4)

@memo
def critic_joint[_f: Food, _l: Location](cutoff=1):
    critic: knows(_f, _l)
    critic: thinks[
        restaurant: knows(_f, _l),
        restaurant: chooses(f in Food, wpp=normpdfjit(f, 0, 0.4)),
        restaurant: chooses(l in Location, wpp=normpdfjit(l, 0, 0.4)),
        clientele: knows(restaurant.f, restaurant.l),
        clientele: chooses(p in Popularity, wpp=popularity_pdf(p, restaurant.f, restaurant.l)),
    ]
    return critic[Pr[restaurant.f == _f, restaurant.l == _l]]
g = jointprob_facetgrid(
    critic_joint(),
    Food,
    Location,
    label_dim0="Location",
    label_dim1="Food",
    title=r"$P(Location,Food)$",
    cmap="inferno")
This plot shows the probability density of { P(\text{location}, \text{food}) }, as well as the marginal distributions { P(\text{location}) } above and { P(\text{food}) } to the right. The dashed line is the line of best fit, whose slope reflects the correlation between Location and Food; the red markers on the margins indicate the marginal means.
As you can see, there’s no a priori association between the quality of food and location.
However, in order for a restaurant to survive, it must attract a certain number of clientele.
In the model below, we introduce a selection criterion: restaurants that do not exceed the minimum popularity cutoff go out of business.
Below, we condition the model on {\text{popularity} > 1}. Restaurants more popular than the cutoff survive while the others go out of business. We depict this in the DAG by shading the “popularity” node.
@memo
def critic_posterior[_f: Food, _l: Location](cutoff=1):
    critic: knows(_f, _l)
    critic: thinks[
        restaurant: knows(_f, _l),
        restaurant: chooses(f in Food, wpp=normpdfjit(f, 0, 0.4)),
        restaurant: chooses(l in Location, wpp=normpdfjit(l, 0, 0.4)),
        clientele: knows(restaurant.f, restaurant.l),
        clientele: chooses(p in Popularity, wpp=popularity_pdf(p, restaurant.f, restaurant.l)),
    ]
    critic: observes_that[ popularity_value(clientele.p) > cutoff ]
    return critic[Pr[restaurant.f == _f, restaurant.l == _l]]
cutoff_ = 1
g = jointprob_facetgrid(
    critic_posterior(cutoff=cutoff_),
    Food,
    Location,
    label_dim0="Location",
    label_dim1="Food",
    title=fr"$P(Location,Food \mid popularity > {cutoff_})$",
    cmap="inferno")
The selection criterion induces a negative correlation between Location and Food. Try changing the cutoff value and see what happens to the correlation. The more stringent the selection, the stronger the correlation.
Really good restaurants can survive in bad locations because they’re worth the effort of getting to, and terrible restaurants can survive in Hanover because it’s hard for the population to go elsewhere.
This simple example illustrates a very common phenomenon in causal inference: selection on a variable changes the associations among its causal factors.
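As a quick check of that claim, here is a small sketch (not part of the original analysis) that reuses critic_posterior, Food, and Location from above and reports the correlation induced by increasingly stringent popularity cutoffs, using the same correlation computation as jointprob_facetgrid:
import numpy as np

def joint_correlation(P, support_dim0, support_dim1):
    # Correlation of a 2D joint probability table (same math as inside jointprob_facetgrid)
    P = np.asarray(P)
    A, B = np.meshgrid(support_dim0, support_dim1)
    E_a = np.sum(A * P.T) / np.sum(P)
    E_b = np.sum(B * P.T) / np.sum(P)
    cov_ab = np.sum((A - E_a) * (B - E_b) * P.T) / np.sum(P)
    var_a = np.sum((A - E_a) ** 2 * P.T) / np.sum(P)
    var_b = np.sum((B - E_b) ** 2 * P.T) / np.sum(P)
    return cov_ab / np.sqrt(var_a * var_b)

for cutoff_ in [0, 1, 2, 3]:
    r = joint_correlation(critic_posterior(cutoff=cutoff_), Food, Location)
    print(f"popularity > {cutoff_}: corr(Food, Location) = {r:+.3f}")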
%reset -f
import sys
import platform
import importlib.metadata
print("Python:", sys.version)
print("Platform:", platform.system(), platform.release())
print("Processor:", platform.processor())
print("Machine:", platform.machine())
print("\nPackages:")
for name, version in sorted(
    ((dist.metadata["Name"], dist.version) for dist in importlib.metadata.distributions()),
    key=lambda x: x[0].lower()  # Sort case-insensitively
):
    print(f"{name}=={version}")
Python: 3.13.2 (main, Feb 5 2025, 18:58:04) [Clang 19.1.6 ]
Platform: Darwin 23.6.0
Processor: arm
Machine: arm64
Packages:
annotated-types==0.7.0
anyio==4.9.0
appnope==0.1.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
astroid==3.3.9
asttokens==3.0.0
async-lru==2.0.5
attrs==25.3.0
babel==2.17.0
beautifulsoup4==4.13.3
bleach==6.2.0
certifi==2025.1.31
cffi==1.17.1
cfgv==3.4.0
charset-normalizer==3.4.1
click==8.1.8
comm==0.2.2
contourpy==1.3.1
cycler==0.12.1
debugpy==1.8.13
decorator==5.2.1
defusedxml==0.7.1
dill==0.3.9
distlib==0.3.9
executing==2.2.0
fastjsonschema==2.21.1
filelock==3.18.0
fonttools==4.56.0
fqdn==1.5.1
h11==0.14.0
httpcore==1.0.7
httpx==0.28.1
identify==2.6.9
idna==3.10
importlib_metadata==8.6.1
ipykernel==6.29.5
ipython==9.0.2
ipython_pygments_lexers==1.1.1
ipywidgets==8.1.5
isoduration==20.11.0
isort==6.0.1
jax==0.5.3
jaxlib==0.5.3
jedi==0.19.2
Jinja2==3.1.6
joblib==1.4.2
json5==0.10.0
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
jupyter-cache==1.0.1
jupyter-events==0.12.0
jupyter-lsp==2.2.5
jupyter_client==8.6.3
jupyter_core==5.7.2
jupyter_server==2.15.0
jupyter_server_terminals==0.5.3
jupyterlab==4.3.6
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.3
jupyterlab_widgets==3.0.13
kiwisolver==1.4.8
MarkupSafe==3.0.2
matplotlib==3.10.1
matplotlib-inline==0.1.7
mccabe==0.7.0
memo-lang==1.1.2
mistune==3.1.3
ml_dtypes==0.5.1
nbclient==0.10.2
nbconvert==7.16.6
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.4.2
nodeenv==1.9.1
notebook_shim==0.2.4
numpy==2.2.4
opt_einsum==3.4.0
optype==0.9.2
overrides==7.7.0
packaging==24.2
pandas==2.2.3
pandas-stubs==2.2.3.250308
pandocfilters==1.5.1
parso==0.8.4
pexpect==4.9.0
pillow==11.1.0
platformdirs==4.3.7
plotly==5.24.1
pre_commit==4.2.0
prometheus_client==0.21.1
prompt_toolkit==3.0.50
psutil==7.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
pycparser==2.22
pydantic==2.10.6
pydantic_core==2.27.2
Pygments==2.19.1
pygraphviz==1.14
pylint==3.3.6
pyparsing==3.2.3
python-dateutil==2.9.0.post0
python-dotenv==1.1.0
python-json-logger==3.3.0
pytz==2025.2
PyYAML==6.0.2
pyzmq==26.3.0
referencing==0.36.2
requests==2.32.3
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.23.1
ruff==0.11.2
scikit-learn==1.6.1
scipy==1.15.2
scipy-stubs==1.15.2.1
seaborn==0.13.2
Send2Trash==1.8.3
setuptools==78.1.0
six==1.17.0
sniffio==1.3.1
soupsieve==2.6
SQLAlchemy==2.0.39
stack-data==0.6.3
tabulate==0.9.0
tenacity==9.0.0
terminado==0.18.1
threadpoolctl==3.6.0
tinycss2==1.4.0
toml==0.10.2
tomlkit==0.13.2
tornado==6.4.2
tqdm==4.67.1
traitlets==5.14.3
types-python-dateutil==2.9.0.20241206
types-pytz==2025.1.0.20250318
typing_extensions==4.12.2
tzdata==2025.2
uri-template==1.3.0
urllib3==2.3.0
virtualenv==20.29.3
wcwidth==0.2.13
webcolors==24.11.1
webencodings==0.5.1
websocket-client==1.8.0
widgetsnbextension==4.0.13
xarray==2025.3.0
zipp==3.21.0