Imagine you have a bag with 10 coins in it. Of these coins, 9 are fair and 1 is biased with {P(\text{H}) = 0.7}. The biased coin is painted red.
You pull a coin from the bag, look at it, and flip it three times. You record the sequence and whether it was the biased coin or a fair coin. You then put the coin back in the bag. You repeat this thousands of times until you have an estimate of the probability of every sequence, both for the biased coin and for a fair coin.
You now pass the bag to a friend and close your eyes. Your friend draws a coin, flips it three times, and tells you it produced a sequence with two \text{H}. What’s the probability that your friend flipped the biased coin?
Analyzing the Experiment
Based on your many coin flips, you can analyze the possible outcomes of this experiment.
An outcome is a possible result of an experiment or trial. Each possible outcome of a particular experiment is unique, and different outcomes are mutually exclusive (only one outcome will occur on each trial of the experiment). All of the possible outcomes of an experiment form the elements of a sample space.
An event is a set of outcomes. Because individual outcomes may be of little practical interest, or because there may be prohibitively (even infinitely) many of them, outcomes are grouped into sets that satisfy some condition; these sets are called “events.” A single outcome can be a part of many different events.
The sample space of an experiment or random trial is the set of all possible outcomes or results of that experiment.
Sample Space for Coin Flips: The sample space of the sequence of three flips of a coin is
S = \{\text{H},\text{T}\}^3 = \{\text{HHH}, \text{HHT}, \text{HTH}, \text{THH}, \text{HTT}, \text{THT}, \text{TTH}, \text{TTT} \}
This set lists every possible sequence of heads (\text{H}) and tails (\text{T}) from three flips.
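To make this concrete, the sample space can be enumerated in a few lines of Python (a standalone sketch, separate from the memo model later in the post):

```python
from itertools import product

# Enumerate S = {H, T}^3: every possible sequence of three flips
S = ["".join(seq) for seq in product("HT", repeat=3)]
print(S)       # ['HHH', 'HHT', 'HTH', 'HTT', 'THH', 'THT', 'TTH', 'TTT']
print(len(S))  # 2**3 = 8
```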
Sample Space for Coin Selection: The sample space of drawing a coin from the bag is
C = \{\text{fair}, \text{biased}\}
Overall Sample Space: The sample space of the entire experiment (drawing a coin and flipping it) is the product of these sample spaces:
\Omega = C \times S = \{(c, s) : c \in C, s \in S \}
This means each outcome in the experiment is a pair: (type of coin, sequence of flips).
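The product space can be built the same way (a standalone Python sketch; the name Omega is mine and is not part of the memo model below):

```python
from itertools import product

C = ["fair", "biased"]
S = ["".join(seq) for seq in product("HT", repeat=3)]

# Omega = C x S: each outcome is a (coin type, flip sequence) pair
Omega = [(c, s) for c in C for s in S]
print(len(Omega))  # 2 * 8 = 16 outcomes
```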
Visualizing with Sets
We can think of the sample space \Omega as the set of all of the outcomes possible in the experiment. Each point in the sample space corresponds to a unique outcome. We can draw a circle around all of the points that are described by an event, for instance that the sequence has exactly 2 \text{H}. This set of points is the data (d). In this example, observing d means that we assume (in a probabilistic sense) the event (that there were 2 HEADS) has occurred. That event comprises all the outcomes (a sequence of three coin flips) where exactly two of the flips came up HEADS.
We can draw another circle around all of the points that correspond to sequences produced by the biased coin. We’ll call this set our hypothesis (h) that the coin your friend flipped was biased.
What we want to know is {P(h \mid d)}, i.e., {P(\text{biased} \mid 2~\text{H})}. This is the probability that the coin is biased given that we observed two HEADS.
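These circles are just subsets of \Omega, which we can write down explicitly (a standalone Python sketch; the names d and h mirror the text):

```python
from itertools import product

S = ["".join(seq) for seq in product("HT", repeat=3)]
Omega = [(c, s) for c in ["fair", "biased"] for s in S]

# d: the event that the sequence has exactly 2 heads
d = {(c, s) for (c, s) in Omega if s.count("H") == 2}
# h: the event that the coin drawn was the biased one
h = {(c, s) for (c, s) in Omega if c == "biased"}

print(len(d))      # 6 outcomes (3 two-head sequences x 2 coin types)
print(len(h))      # 8 outcomes (every sequence with the biased coin)
print(len(h & d))  # 3 outcomes in the intersection
```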
Deriving Bayes’ Rule
The intersection of h and d represents all the outcomes where the biased coin was flipped and the sequence contains exactly two HEADS. We can write this probability as {P(h \cap d)}, or equivalently, P(h, d).
To find {P(h \mid d)}, we want the proportion of d that is occupied by h. This is the ratio of the total measure of the intersection {h \cap d} to the total measure of d:
P(h \mid d) = \frac{P(h, d)}{P(d)}
This is the definition of conditional probability. It holds whenever {P(d) > 0}.
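We can check this definition numerically by assigning a probability to every outcome in \Omega and taking the ratio directly (a plain-Python sketch using the numbers from this example; it does not use memo):

```python
from itertools import product

S = ["".join(seq) for seq in product("HT", repeat=3)]

def seq_prob(s, p_heads):
    # Probability of a particular flip sequence for a coin with P(H) = p_heads
    prob = 1.0
    for flip in s:
        prob *= p_heads if flip == "H" else 1 - p_heads
    return prob

# P(c, s) = P(c) * P(s | c), with P(biased) = 1/10 and P(H | biased) = 0.7
P = {("fair", s): 0.9 * seq_prob(s, 0.5) for s in S}
P.update({("biased", s): 0.1 * seq_prob(s, 0.7) for s in S})

p_d = sum(p for (c, s), p in P.items() if s.count("H") == 2)
p_hd = sum(p for (c, s), p in P.items() if c == "biased" and s.count("H") == 2)
print(p_hd / p_d)  # P(h | d), the posterior we're after
```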
Algebraic rearrangement yields
P(h, d) = P(d) \, P(h \mid d)
This is the chain rule. It says that we can calculate the measure of the intersection by multiplying the proportion of d occupied by h, by the total measure of d.
The chain rule holds symmetrically in h and d, so {P(h, d) = P(h) \, P(d \mid h)} as well. Substituting this into the definition of conditional probability yields
P(h \mid d) = \frac{P(h) \, P(d \mid h)}{P(d)}
This is Bayes’ rule. It follows directly from the definition of conditional probability. Bayes’ rule is just probability.
There are names for each of the terms.
In the case of this example:
Posterior {P(h \mid d)}: The probability that the coin is biased given that it produced the data d (a sequence with exactly 2 HEADS). This is what we want to know.
Prior {P(h)}: The probability that the coin is biased before seeing any data. This is 1/10 (since there’s 1 biased coin out of 10).
Likelihood {P(d \mid h)}: The probability of getting 2 HEADS in 3 flips given that the coin is biased. We’ll calculate this below.
Evidence {P(d)}: The overall probability of getting 2 HEADS (from either coin). We’ll also calculate this below.
Applying Bayes’ Rule
On paper
Let’s now fill in our probabilities to answer the question.
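The likelihood of exactly two HEADS in three flips, for each coin, is

P(d \mid h) = \binom{3}{2} (0.7)^2 (0.3)^1 = 0.441

P(d \mid \text{fair}) = \binom{3}{2} (0.5)^2 (0.5)^1 = 0.375

The evidence marginalizes over both ways of drawing a coin:

P(d) = P(h) \, P(d \mid h) + P(\text{fair}) \, P(d \mid \text{fair}) = (0.1)(0.441) + (0.9)(0.375) = 0.3816

Putting it all together with the prior {P(h) = 1/10}:

P(h \mid d) = \frac{P(h) \, P(d \mid h)}{P(d)} = \frac{0.0441}{0.3816} \approx 0.1156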
Ergo, the probability that your friend flipped the biased coin is approximately 0.1156.
In the memo PPL
import jax
import jax.numpy as jnp
from memo import memo
from memo import domain as product
from enum import IntEnum

class Coin(IntEnum):
    TAILS = 0
    HEADS = 1

class Bag(IntEnum):
    FAIR = 0
    BIASED = 1

nflips = 3
S = product(**{f"f{flip+1}": len(Coin) for flip in range(nflips)})
NumHeads = jnp.arange(nflips + 1)

@jax.jit
def sum_seq(s):
    return S.f1(s) + S.f2(s) + S.f3(s)

@jax.jit
def pmf(s, c):
    ### probability of heads for fair and biased coin
    p_h = jnp.array([0.5, 0.7])[c]
    ### P(T) and P(H) for coin c
    p = jnp.array([1 - p_h, p_h])
    ### probability of the outcome of each flip
    p1 = p[S.f1(s)]
    p2 = p[S.f2(s)]
    p3 = p[S.f3(s)]
    ### probability of the sequence s
    return p1 * p2 * p3

@memo
def experiment[_numheads: NumHeads]():
    ### observer's mental model of the process by which
    ### the friend determined the number of heads
    observer: thinks[
        ### friend draws a coin from the bag
        friend: chooses(c in Bag, wpp=0.1 if c == {Bag.BIASED} else 0.9),
        ### friend flips the coin 3x
        friend: given(s in S, wpp=pmf(s, c)),
        ### friend counts the number of HEADS
        friend: given(numheads in NumHeads, wpp=(numheads == sum_seq(s)))
    ]
    ### observer learns the number of heads from friend
    observer: observes [friend.numheads] is _numheads
    ### query the observer: what's the probability that
    ### the coin c, which your friend flipped, was biased?
    return observer[Pr[friend.c == {Bag.BIASED}]]

xa = experiment(print_table=True, return_xarray=True).aux.xarray
numheads = 2
print(f"\n\nP(h=biased | d={numheads}heads) = {xa.loc[numheads].item()}\n")