Bayes’ rule

Imagine you have a bag with 10 coins in it. Of these coins, 9 are fair and 1 is biased with {P(\text{H}) = 0.7}. The biased coin is painted red.

You pull a coin from the bag, look at it, and flip it three times. You record the sequence and whether it was the biased coin or a fair coin. You then put the coin back in the bag. You repeat this thousands of times until you have an estimate of the probability of every sequence, both for the biased coin and for a fair coin.

You now pass the bag to a friend and close your eyes. Your friend draws a coin, flips it three times, and tells you it produced a sequence with two \text{H}. What’s the probability that your friend flipped the biased coin?
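
Before working this out on paper, here is a minimal Monte Carlo sketch (assuming numpy; the variable names are illustrative) that estimates the answer by simulating the procedure many times, just like the record-keeping described above:

import numpy as np

rng = np.random.default_rng(0)
n = 1_000_000

# draw a coin each trial: True = biased (1 in 10), False = fair
biased = rng.random(n) < 0.1
# flip the drawn coin three times and count the HEADS
heads = rng.binomial(3, np.where(biased, 0.7, 0.5))

# among the trials that produced exactly two HEADS,
# how often was the coin biased?
print(biased[heads == 2].mean())  # ≈ 0.116

The rest of this section derives that number exactly.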

Analyzing the Experiment

Based on your many coin flips, you can analyze the possible outcomes of this experiment.

experiment / trial
An experiment or trial is any procedure that can be infinitely repeated and has a well-defined set of possible outcomes, known as the sample space.
outcome
An outcome is a possible result of an experiment or trial. Each possible outcome of a particular experiment is unique, and different outcomes are mutually exclusive (only one outcome will occur on each trial of the experiment). All of the possible outcomes of an experiment form the elements of a sample space.
event
An event is a set of outcomes. Since individual outcomes may be of little practical interest, or because there may be prohibitively (even infinitely) many of them, outcomes are grouped into sets of outcomes that satisfy some condition, which are called “events.” A single outcome can be a part of many different events.
sample space
The sample space of an experiment or random trial is the set of all possible outcomes or results of that experiment.
  • Sample Space for Coin Flips: The sample space for the sequence of three coin flips is S = \{\text{H},\text{T}\}^3 = \{\text{HHH}, \text{HHT}, \text{HTH}, \text{THH}, \text{HTT}, \text{THT}, \text{TTH}, \text{TTT} \}. This set lists every possible sequence of heads (\text{H}) and tails (\text{T}) from three flips.

  • Sample Space for Coin Selection: The sample space of drawing a coin from the bag is C = \{\text{fair}, \text{biased}\}

  • Overall Sample Space: The sample space of the entire experiment (drawing a coin and flipping it) is the product of these sample spaces: \Omega = C \times S = \{(c, s) : c \in C, s \in S \} This means each outcome in the experiment is a pair: (type of coin, sequence of flips).
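
These sample spaces are small enough to enumerate directly; a quick sketch with Python's itertools (the names S, C, and Omega mirror the notation above):

from itertools import product

S = list(product("HT", repeat=3))        # {H, T}^3: the 8 flip sequences
C = ["fair", "biased"]                   # coin types
Omega = [(c, s) for c in C for s in S]   # C × S: (coin, sequence) pairs

print(len(S), len(C), len(Omega))        # 8 2 16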

Visualizing with Sets

We can think of the sample space \Omega as the set of all of the outcomes possible in the experiment. Each point in the sample space corresponds to a unique outcome. We can draw a circle around all of the points that are described by an event, for instance that the sequence has exactly 2 \text{H}. This set of points is the data (d). In this example, observing d means that we assume (in a probabilistic sense) that the event (that there were 2 HEADS) has occurred. That event comprises all the outcomes (sequences of three coin flips) where exactly two of the flips came up HEADS.

We can draw another circle around all of the points that correspond to sequences produced by the biased coin. We’ll call this set our hypothesis (h) that the coin your friend flipped was biased.

What we want to know is {P(h \mid d)}, i.e., {P(\text{biased} \mid 2~\text{H})}. This is the probability that the coin is biased given that we observed two HEADS.

Deriving Bayes’ Rule

The intersection of h and d represents all the outcomes where the biased coin was flipped and the sequence contains exactly two HEADS. We can write this probability as {P(h \cap d)}, or equivalently, P(h, d).

To find {P(h \mid d)}, we want the proportion of d that is occupied by h. This is the ratio of the total measure of the intersection {h \cap d} to the total measure of d:

P(h \mid d) = \frac{P(h, d)}{P(d)}

This is the definition of conditional probability. It holds whenever {P(d) > 0}.
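
We can check this definition by brute force: assign each point of \Omega its probability, then take the ratio of the two sums (a sketch; p_outcome is a hypothetical helper, and the prior and bias values come from the setup above):

from itertools import product

def p_outcome(coin, seq):
    # P(coin) * P(seq | coin): the probability of one point of Ω
    p_coin = 0.1 if coin == "biased" else 0.9
    p_head = 0.7 if coin == "biased" else 0.5
    p_seq = 1.0
    for flip in seq:
        p_seq *= p_head if flip == "H" else 1 - p_head
    return p_coin * p_seq

Omega = [(c, s) for c in ("fair", "biased") for s in product("HT", repeat=3)]

# P(d): all outcomes with exactly 2 HEADS
p_d = sum(p_outcome(c, s) for c, s in Omega if s.count("H") == 2)
# P(h, d): the intersection, biased coin AND exactly 2 HEADS
p_h_and_d = sum(p_outcome(c, s) for c, s in Omega
                if c == "biased" and s.count("H") == 2)

print(p_h_and_d / p_d)  # P(h | d)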

Algebraic rearrangement yields

P(h, d) = P(d) \, P(h \mid d)

This is the chain rule. It says that we can calculate the measure of the intersection by multiplying the proportion of d occupied by h by the total measure of d.

The chain rule is symmetric, so we can write both

\begin{align*} P(h, d) = P(d) \, P(h \mid d) \\ \\ P(h, d) = P(h) \, P(d \mid h) \end{align*}

Returning to the quantity that we want to know,

P(h \mid d) = \frac{P(h, d)}{P(d)}

Substitution of {P(h, d)} with {P(h) \, P(d \mid h)} yields

P(h \mid d) = \frac{P(h) \, P(d \mid h)}{P(d)}

This is Bayes’ rule. It follows directly from the definition of conditional probability. Bayes’ rule is just probability.

There are names for each of the terms.

In the case of this example:

  • Posterior {P(h \mid d)}: The probability that the coin is biased given that it produced the data d (a sequence with exactly 2 HEADS). This is what we want to know.
  • Prior {P(h)}: The probability that the coin is biased before seeing any data. This is 1/10 (since there’s 1 biased coin out of 10).
  • Likelihood {P(d \mid h)}: The probability of getting 2 HEADS in 3 flips given that the coin is biased. We’ll calculate this below.
  • Evidence {P(d)}: The overall probability of getting 2 HEADS (from either coin). We’ll also calculate this below.

Applying Bayes’ Rule

On paper

Let’s now fill in our probabilities to answer the question.

We know that,

P(h) = \frac{1}{10} = 0.1

since there is 1 biased coin and 9 fair coins.

We also know the likelihood for the biased coin. The probability of getting exactly 2 HEADS in 3 flips follows a binomial distribution:

P(d \mid h) = \binom{3}{2} \cdot (0.7)^2 \cdot (1 - 0.7)^1

The binomial coefficient \binom{3}{2} = 3 counts the elements of S that contain exactly 2 HEADS (\text{HHT}, \text{HTH}, \text{THH}), so

P(d \mid h) = (3) \cdot (0.7)^2 \cdot (0.3) = 0.441
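
As a quick sanity check of this arithmetic (using Python's built-in math.comb):

from math import comb

# P(d | h): exactly 2 HEADS in 3 flips of the biased coin
print(comb(3, 2) * 0.7**2 * 0.3)  # ≈ 0.441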

To calculate {P(d)}, we need to consider both the case where the coin is biased and the case where the coin is fair.

P(d) = P(h) \, P(d \mid h) + P(\neg h) \, P(d \mid \neg h)

Or more generally (e.g. when there are not just two mutually exclusive hypotheses):

P(d) = \sum_{h' \in \mathcal{H}} P(h') \, P(d \mid h')

For this coin example, h is that the coin is biased and \neg h is that the coin is fair:

\begin{align*} P(d) &= P(d, h) + P(d, \neg h) \\ &= P(h) \, P(d \mid h) + P(\neg h) \, P(d \mid \neg h) \\ &= (0.1) \cdot (0.441) + (0.9) \cdot \binom{3}{2} \cdot (0.5)^2 \cdot (1 - 0.5)^1 \\ &= (0.1) \cdot (0.441) + (0.9) \cdot (3) \cdot (0.5)^2 \cdot (0.5) \\ &= 0.0441 + 0.3375 \\ &= 0.3816 \end{align*}
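
The same marginalization can be written as the generic sum over \mathcal{H} from above; a minimal sketch (the dict keys are illustrative):

from math import comb

# P(h') and P(d | h') for each hypothesis h'
prior = {"biased": 0.1, "fair": 0.9}
lik = {
    "biased": comb(3, 2) * 0.7**2 * 0.3,  # 0.441
    "fair": comb(3, 2) * 0.5**2 * 0.5,    # 0.375
}

# P(d) = Σ_h' P(h') P(d | h')
p_d = sum(prior[h] * lik[h] for h in prior)
print(p_d)  # 0.3816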

Finally, we can plug everything into Bayes’ rule:

\begin{align*} P(h \mid d) &= \frac{P(h) \, P(d \mid h)}{P(d)} \\ \\ &= \frac{0.1 \cdot 0.441}{0.3816} \\ \\ &\approx 0.1156 \end{align*}

Ergo, the probability that your friend flipped the biased coin is approximately 0.1156.
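
Putting it all together, the entire on-paper calculation fits in a few lines of plain Python (a sketch mirroring the steps above), before we repeat it in memo:

from math import comb

p_h = 0.1                                # prior: 1 biased coin in 10
lik_h = comb(3, 2) * 0.7**2 * 0.3        # P(d | h) = 0.441
lik_not_h = comb(3, 2) * 0.5**2 * 0.5    # P(d | ¬h) = 0.375

p_d = p_h * lik_h + (1 - p_h) * lik_not_h  # evidence = 0.3816
print(p_h * lik_h / p_d)                   # posterior ≈ 0.1156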

In the memo PPL

import jax
import jax.numpy as jnp
from memo import memo
from memo import domain as product
from enum import IntEnum

class Coin(IntEnum):
    TAILS = 0
    HEADS = 1

class Bag(IntEnum):
    FAIR = 0
    BIASED = 1

nflips = 3

S = product(**{f"f{flip+1}": len(Coin) for flip in range(nflips)})

NumHeads = jnp.arange(nflips + 1)

@jax.jit
def sum_seq(s):
    return S.f1(s) + S.f2(s) + S.f3(s)

@jax.jit
def pmf(s, c):
    ### probability of heads for fair and biased coin
    p_h = jnp.array([0.5, 0.7])[c]
    ### P(T) and P(H) for coin c
    p = jnp.array([1 - p_h, p_h])
    ### probability of the outcome of each flip
    p1 = p[S.f1(s)]
    p2 = p[S.f2(s)]
    p3 = p[S.f3(s)]
    ### probability of the sequence s
    return p1 * p2 * p3

@memo
def experiment[_numheads: NumHeads]():
    ### observer's mental model of the process by which
    ### the friend determined the number of heads
    observer: thinks[
        ### friend draws a coin from the bag
        friend: chooses(c in Bag, wpp=0.1 if c == {Bag.BIASED} else 0.9),
        ### friend flips the coin 3x
        friend: given(s in S, wpp=pmf(s, c)),
        ### friend counts the number of HEADS
        friend: given(numheads in NumHeads, wpp=(numheads==sum_seq(s)))
    ]
    ### observer learns the number of heads from friend
    observer: observes [friend.numheads] is _numheads

    ### query the observer: what's the probability that 
    ### the coin c, which your friend flipped, was biased?
    return observer[Pr[friend.c == {Bag.BIASED}]]

xa = experiment(print_table=True, return_xarray=True).aux.xarray

numheads = 2
print(f"\n\nP(h=biased | d={numheads}heads) = {xa.loc[numheads].item()}\n")
+---------------------+-----------------------+
| _numheads: NumHeads | experiment            |
+---------------------+-----------------------+
| 0                   | 0.0234375037252903    |
| 1                   | 0.053030312061309814  |
| 2                   | 0.11556604504585266   |
| 3                   | 0.23365122079849243   |
+---------------------+-----------------------+


P(h=biased | d=2heads) = 0.11556604504585266

%reset -f
import sys
import platform
import importlib.metadata

print("Python:", sys.version)
print("Platform:", platform.system(), platform.release())
print("Processor:", platform.processor())
print("Machine:", platform.machine())

print("\nPackages:")
for name, version in sorted(
    ((dist.metadata["Name"], dist.version) for dist in importlib.metadata.distributions()),
    key=lambda x: x[0].lower()  # Sort case-insensitively
):
    print(f"{name}=={version}")
Python: 3.13.2 (main, Feb  5 2025, 18:58:04) [Clang 19.1.6 ]
Platform: Darwin 23.6.0
Processor: arm
Machine: arm64

Packages:
annotated-types==0.7.0
anyio==4.8.0
appnope==0.1.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==3.0.0
async-lru==2.0.4
attrs==25.1.0
babel==2.17.0
beautifulsoup4==4.13.3
bleach==6.2.0
certifi==2025.1.31
cffi==1.17.1
cfgv==3.4.0
charset-normalizer==3.4.1
click==8.1.8
comm==0.2.2
contourpy==1.3.1
cycler==0.12.1
debugpy==1.8.13
decorator==5.2.1
defusedxml==0.7.1
distlib==0.3.9
executing==2.2.0
fastjsonschema==2.21.1
filelock==3.17.0
fonttools==4.56.0
fqdn==1.5.1
h11==0.14.0
httpcore==1.0.7
httpx==0.28.1
identify==2.6.8
idna==3.10
importlib_metadata==8.6.1
ipykernel==6.29.5
ipython==9.0.1
ipython_pygments_lexers==1.1.1
ipywidgets==8.1.5
isoduration==20.11.0
jax==0.5.2
jaxlib==0.5.1
jedi==0.19.2
Jinja2==3.1.6
joblib==1.4.2
json5==0.10.0
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
jupyter-cache==1.0.1
jupyter-events==0.12.0
jupyter-lsp==2.2.5
jupyter_client==8.6.3
jupyter_core==5.7.2
jupyter_server==2.15.0
jupyter_server_terminals==0.5.3
jupyterlab==4.3.5
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.3
jupyterlab_widgets==3.0.13
kiwisolver==1.4.8
MarkupSafe==3.0.2
matplotlib==3.10.1
matplotlib-inline==0.1.7
memo-lang==1.1.0
mistune==3.1.2
ml_dtypes==0.5.1
nbclient==0.10.2
nbconvert==7.16.6
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.4.2
nodeenv==1.9.1
notebook_shim==0.2.4
numpy==2.2.3
opt_einsum==3.4.0
optype==0.9.1
overrides==7.7.0
packaging==24.2
pandas==2.2.3
pandas-stubs==2.2.3.241126
pandocfilters==1.5.1
parso==0.8.4
pexpect==4.9.0
pillow==11.1.0
platformdirs==4.3.6
plotly==5.24.1
pre_commit==4.1.0
prometheus_client==0.21.1
prompt_toolkit==3.0.50
psutil==7.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
pycparser==2.22
pydantic==2.10.6
pydantic_core==2.27.2
Pygments==2.19.1
pygraphviz==1.14
pyparsing==3.2.1
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-json-logger==3.3.0
pytz==2025.1
PyYAML==6.0.2
pyzmq==26.2.1
referencing==0.36.2
requests==2.32.3
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.23.1
ruff==0.9.10
scikit-learn==1.6.1
scipy==1.15.2
scipy-stubs==1.15.2.0
seaborn==0.13.2
Send2Trash==1.8.3
setuptools==75.8.2
six==1.17.0
sniffio==1.3.1
soupsieve==2.6
SQLAlchemy==2.0.38
stack-data==0.6.3
tabulate==0.9.0
tenacity==9.0.0
terminado==0.18.1
threadpoolctl==3.5.0
tinycss2==1.4.0
toml==0.10.2
tornado==6.4.2
tqdm==4.67.1
traitlets==5.14.3
types-python-dateutil==2.9.0.20241206
types-pytz==2025.1.0.20250204
typing_extensions==4.12.2
tzdata==2025.1
uri-template==1.3.0
urllib3==2.3.0
virtualenv==20.29.3
wcwidth==0.2.13
webcolors==24.11.1
webencodings==0.5.1
websocket-client==1.8.0
widgetsnbextension==4.0.13
xarray==2025.1.2
zipp==3.21.0