{ "cells": [ { "cell_type": "markdown", "id": "90ab98b8-e26d-4cab-bc32-bf7cdee4d631", "metadata": {}, "source": [ "# Generative Models 1\n", "\n", "Ontological modeling — formalizing a hypothesis of the data-generating\n", "process\n", "\n", "## Models, simulation, and degrees of belief\n", "\n", "One view of knowledge is that the mind maintains working models of parts\n", "of the world. ‘Model’ in the sense that it captures some of the\n", "structure in the world, but not all (and what it captures need not be\n", "exactly what is in the world—just what is useful). ‘Working’ in the\n", "sense that it can be used to simulate this part of the world, imagining\n", "what will follow from different initial conditions. As an example take\n", "the Plinko machine: a box with uniformly spaced pegs, with bins at the\n", "bottom. Into this box we can drop marbles:\n", "\n", "The plinko machine is a ‘working model’ for many physical processes in\n", "which many small perturbations accumulate—for instance a leaf falling\n", "from a tree. It is an approximation to these systems because we use a\n", "discrete grid (the pegs) and discrete bins. Yet it is useful as a model:\n", "for instance, we can ask where we expect a marble to end up depending on\n", "where we drop it in, by running the machine several times—simulating the\n", "outcome.\n", "\n", "Imagine that someone has dropped a marble into the plinko machine;\n", "before looking at the outcome, you can probably report how much you\n", "believe that the ball has landed in each possible bin. Indeed, if you\n", "run the plinko machine many times, you will see a shape emerge in the\n", "bins. The number of balls in a bin gives you some idea of how much you\n", "should expect a new marble to end up there. This ‘shape of expected\n", "outcomes’ can be formalized as a probability distribution (described\n", "below). Indeed, there is an intimate connection between simulation,\n", "expectation or belief, and probability, which we explore in the rest of\n", "this section.\n", "\n", "There is one more thing to note about our Plinko machine above: we are\n", "using a computer program to simulate the simulation. Computers can be\n", "seen as universal simulators. How can we, clearly and precisely,\n", "describe the simulation we want a computer to do?\n", "\n", "## Building Generative Models\n", "\n", "We wish to describe in formal terms how to generate states of the world.\n", "That is, we wish to describe the causal process, or steps that unfold,\n", "leading to some potentially observable states. The key idea of this\n", "section is that these generative processes can be described as\n", "*computations*—computations that involve random choices to capture\n", "uncertainty about the process.\n", "\n", "Programming languages are formal systems for describing what\n", "(deterministic) computation a computer should do. Modern programming\n", "languages offer a wide variety of different ways to describe\n", "computation; each makes some processes simple to describe and others\n", "more complex. However, a key tenet of computer science is that all of\n", "these languages have the same fundamental power: any computation that\n", "can be described with one programming language can be described by\n", "another. (More technically this [Church-Turing\n", "thesis](https://en.wikipedia.org/wiki/Church%E2%80%93Turing_thesis)\n", "posits that many specific computational systems capture the set of all\n", "effectively computable procedures. These are called universal systems.)\n", "\n", "------------------------------------------------------------------------\n", "\n", "## Two approaches to generative modeling\n", "\n", "Consider how we might simulate a coin being flipped, as random samples\n", "from a Bernoulli distribution." ] }, { "cell_type": "code", "execution_count": null, "id": "9df21702", "metadata": {}, "outputs": [], "source": [ "from scipy.stats import bernoulli\n", "from numpy.random import seed\n", "\n", "ACoin = bernoulli(0.5)" ] }, { "cell_type": "markdown", "id": "f9e4505a-8ca6-4056-af31-956efe561cf9", "metadata": {}, "source": [ "If you run `ACoin.rvs()` multiple times you’ll see that you get `0`\n", "($\\text{TAILS}$) sometimes and `1` ($\\text{HEADS}$) sometimes." ] }, { "cell_type": "code", "execution_count": null, "id": "239dea07", "metadata": {}, "outputs": [], "source": [ "ACoin.rvs()" ] }, { "cell_type": "markdown", "id": "fc0cddd7-5ef3-46b8-b4b7-93d81e5c4406", "metadata": {}, "source": [ "But what’s happening under the hood? It becomes clearer when we set a\n", "random seed:" ] }, { "cell_type": "code", "execution_count": null, "id": "f0393174", "metadata": {}, "outputs": [], "source": [ "seed(100)\n", "ACoin.rvs()\n", "### uncomment to take multiple samples ###\n", "# ACoin.rvs()\n", "# ACoin.rvs()" ] }, { "cell_type": "markdown", "id": "7f1336ef-fcd9-4e94-9e31-e0a47cdf5483", "metadata": {}, "source": [ "When you run this cell multiple times what do you see? There is no more\n", "randomness. Our simulated coin always comes up $\\text{HEADS}$.\n", "\n", "Of course, this only happens because we set the random seed to the same\n", "value right before drawing each sample. If we were to sample `ACoin`\n", "multiple times without resetting the seed, we would draw different\n", "values, and in the limit of infinite samples, the proportion of\n", "$\\text{TAILS}$ and $\\text{HEADS}$ would be equal.\n", "\n", "This trivial example illustrates a property of probability. There is\n", "nothing random about probability distributions. When we write\n", "`bernoulli(0.5)` we’re assigning probability mass to subsets of the\n", "outcomes $\\{\\text{TAILS}, \\text{HEADS}\\}$. When we draw a sample by\n", "calling `.rvs()`, a (pseudo)random number is passed to a deterministic\n", "function that maps the state space of the random number generator to\n", "subsets of the outcomes.\n", "\n", "In other words, one way of building generative models involves drawing\n", "samples according to specified distributions and collecting the results.\n", "For instance, let’s simulate 1000 flips of two fair coins and calculate\n", "how often they both come up $\\text{HEADS}$:" ] }, { "cell_type": "code", "execution_count": null, "id": "797e2539", "metadata": {}, "outputs": [], "source": [ "seed(100)\n", "FairCoin1 = bernoulli(0.5)\n", "FairCoin2 = bernoulli(0.5)\n", "n = 1000\n", "both_heads = 0\n", "for i in range(n):\n", " if FairCoin1.rvs() == 1 and FairCoin2.rvs() == 1:\n", " both_heads += 1\n", "\n", "print(f\"The coins both came up HEADS in {both_heads/n:0.3} proportion of the trials\")" ] }, { "cell_type": "markdown", "id": "ef71e779-cfa2-4fb3-bd4c-78961381df5d", "metadata": {}, "source": [ "But it also often possible and preferable to calculate probabilities of\n", "interest directly. This is the approach taken by `memo`.\n", "\n", "`memo` enables blisteringly fast generative modeling by compiling\n", "probabilistic models down to [JAX](https://jax.readthedocs.io/) *array\n", "programs* (Chandra et al. 2025).\n", "\n", "Let’s explore how `memo` flips coins.\n", "\n", "We’ll start by defining the [sample\n", "space](https://en.wikipedia.org/wiki/Sample_space) of a coin:\n", "$S = \\{ T, H \\}$ (where $T$ and $H$ are $\\text{TAILS}$ and\n", "$\\text{HEADS}$, which are represented by $0$ and $1$, respectively)." ] }, { "cell_type": "code", "execution_count": null, "id": "faafdd9b", "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "Coin1 = jnp.array(\n", " [\n", " 0, # TAILS,\n", " 1, # HEADS\n", " ]\n", ")" ] }, { "cell_type": "markdown", "id": "7731bb18-620a-4033-9405-637e34a610bb", "metadata": {}, "source": [ "JAX is a pretty amazing feat of engineering that sets a gold standard\n", "for efficiency. While widely used, it is still under active development,\n", "and at present, the focus on speed has involved compromises on safety\n", "and flexibility. For instance, notice that while we can index the JAX\n", "array we defined similar to a `numpy` array," ] }, { "cell_type": "code", "execution_count": null, "id": "67d79382", "metadata": {}, "outputs": [], "source": [ "Coin1[0]\n", "Coin1[1]" ] }, { "cell_type": "markdown", "id": "18205244-d6b9-4608-8f80-b78e15240341", "metadata": {}, "source": [ "but unlike `numpy`, JAX does not prevent us from doing things that we\n", "should not be able to do, like indexing outside of the array:" ] }, { "cell_type": "code", "execution_count": null, "id": "051b86d7", "metadata": {}, "outputs": [], "source": [ "Coin1[2]\n", "Coin1[100]" ] }, { "cell_type": "markdown", "id": "9333c30b-6bbd-4e69-9560-6c65dd48bd8d", "metadata": {}, "source": [ "So when using JAX, it’s especially important to examine, probe and\n", "verify your code thoroughly. Incorrect indexing into multidimensional\n", "arrays are a common mistake, and JAX has few builtin protections. For\n", "more information, you can read about [JAX’s sharp\n", "bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).\n", "\n", "> **JAX - The Sharp Bits**\n", ">\n", "> Information about some of the [Common “Gotchas” in\n", "> JAX](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)\n", "\n", "### Enumeration\n", "\n", "Now let’s write a `memo` model that enumerates over the sample space of\n", "the coin." ] }, { "cell_type": "code", "execution_count": null, "id": "0afcdf05", "metadata": {}, "outputs": [], "source": [ "from memo import memo\n", "\n", "Coin1 = jnp.array([0, 1])\n", "\n", "@memo\n", "def f[_c: Coin1]():\n", " return _c\n", "\n", "f()" ] }, { "cell_type": "markdown", "id": "a74e9171-3cd7-4ee4-81b4-9fd896dc0682", "metadata": {}, "source": [ "We defined `f()` to return the outcome `_c` in `Coin1`, so calling `f()`\n", "returns an array of every realization that `_c` can take. We can get a\n", "nice tabular printout using the `print_table` keyword when we call the\n", "model." ] }, { "cell_type": "code", "execution_count": null, "id": "a8946041", "metadata": {}, "outputs": [], "source": [ "f(print_table=True)" ] }, { "cell_type": "markdown", "id": "dc9bdd01-f621-4242-b9c2-716035d4115a", "metadata": {}, "source": [ "JAX arrays are necessarily numeric, but it would be nice if we could\n", "define that $\\text{TAILS} ::= 0$ and $\\text{HEADS} ::= 1$ for the model.\n", "We can do that using `IntEnum` from the standard package `enum`." ] }, { "cell_type": "code", "execution_count": null, "id": "d2dfa859", "metadata": {}, "outputs": [], "source": [ "from enum import IntEnum\n", "\n", "class Coin(IntEnum):\n", " TAILS = 0\n", " HEADS = 1\n", "\n", "@memo\n", "def f_enum[_c: Coin]():\n", " return _c\n", "\n", "res = f_enum(print_table=True)" ] }, { "cell_type": "markdown", "id": "4b320b55-8d87-4187-a86b-35b9029893a9", "metadata": {}, "source": [ "### Enumeration with probability proportional to (`wpp`)\n", "\n", "Let’s now have `memo` flip the coin. We do this using `given` (or\n", "`chooses`, but we’ll get to that later) by specifying the probability\n", "mass on each outcome. `wpp` stands for “with probability proportional\n", "to” and setting it to 1 means a uniform distribution over\n", "${\\_}{c} \\in \\text{Coin}$.\n", "\n", "A key design principle of `memo` is **encapsulation**, meaning that\n", "information is bound to “**frames**” and is not automatically accessible\n", "from outside the frame. We’ll see how important this architecture is\n", "when we start modeling minds’ mental models of other minds’ mental\n", "models. For now, we’ll define an `observer` frame that represents the\n", "outcome `c` of the `Coin` flip. This information is bound to the\n", "observer’s mind, so we always need to access it within the observer\n", "frame (e.g. with `observer.c`).\n", "\n", "Finally, we enumerate over ${\\_}{c} \\in Coin$ and return the probability\n", "(`Pr[]`) that `_c` was the outcome of the coin toss." ] }, { "cell_type": "code", "execution_count": null, "id": "45fbb975", "metadata": {}, "outputs": [], "source": [ "@memo\n", "def g[_c: Coin]():\n", " observer: given(c in Coin, wpp=1)\n", " return Pr[observer.c == _c]\n", "\n", "res = g(print_table=True)" ] }, { "cell_type": "markdown", "id": "1491421d-86b4-4a8f-9de7-d6f8ae03dce4", "metadata": {}, "source": [ "> **bound and unbound information**\n", ">\n", "> I find it useful to clearly differentiate variables bound to frames\n", "> (e.g. `c` in `observer: given(c in Coin, ...`) from unbound variables\n", "> (e.g. `_c` in `[_c: Coin]`).\n", ">\n", "> This is not strictly necessary, `memo` keeps these separate\n", "> internally, such that one could also write\n", ">\n", "> ``` python\n", "> @memo\n", "> def g[c: Coin]():\n", "> observer: given(c in Coin, wpp=1)\n", "> return Pr[observer.c == c]\n", "> ```\n", ">\n", "> in which case the `c` in `observer: given(c in Coin, ...)` refers the\n", "> bound `c` whereas the `c` in `Pr[... == c]` refers to the unbound `c`.\n", ">\n", "> In this course, I will typically use a convention of denoting unbound\n", "> variables with a leading underscore.\n", "\n", "### Assigning probability mass\n", "\n", "Of course, not all distributions are uniform. We use `wpp` to specify\n", "the probability mass of outcomes.\n", "\n", "E.g., to model a biased coin, we can specify that there’s greater\n", "probability mass on $\\text{TAILS}$ than on $\\text{HEADS}$.\n", "\n", "One way to do this is with a ternary.[^1] Rather than `wpp=1`, we can\n", "write\n", "\n", "[^1]: Ternary conditionals in Python take the form\n", "\n", " ``` python\n", " value_if_true if condition else value_if_false\n", " ```\n", "\n", " E.g.,\n", "\n", " ``` python\n", " even = True if x % 2 == 0 else False\n", " ```\n", "\n", " *Ternary* means “composed of three parts” (from the Latin *ternarius*)." ] }, { "cell_type": "code", "execution_count": null, "id": "6aa729f7", "metadata": {}, "outputs": [], "source": [ "@memo\n", "def f_ternary[_c: Coin]():\n", " observer: given(c in Coin, wpp=0.8 if c == 0 else 0.2)\n", " return Pr[observer.c == _c]\n", "\n", "res = f_ternary(print_table=True)" ] }, { "cell_type": "markdown", "id": "9121f170-904e-48e1-b5c1-ed21159e7369", "metadata": {}, "source": [ "Alternatively, we can define a custom probability mass function as a\n", "`@jax.jit` that we pass as `wpp`." ] }, { "cell_type": "code", "execution_count": null, "id": "59919ea6", "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def biased_coin_pmf(c):\n", " return jnp.array([0.8, 0.2])[c]\n", "\n", "@memo\n", "def f_jit[_c: Coin]():\n", " observer: given(c in Coin, wpp=biased_coin_pmf(c))\n", " return Pr[observer.c == _c]\n", "\n", "res = f_jit(print_table=True)" ] }, { "cell_type": "markdown", "id": "2c5f6229-76f3-4d66-95b5-9de1698c0125", "metadata": {}, "source": [ "*Note* that `wpp` normalizes the values passed to it (which is why\n", "`wpp=1` forms a uniform distribution):" ] }, { "cell_type": "code", "execution_count": null, "id": "76e8736a", "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def biased_coin_improper_pmf(c):\n", " return jnp.array([16, 4])[c] ### NB the improper probability masses\n", "\n", "@memo\n", "def f_jit_autonorm[_c: Coin]():\n", " observer: given(c in Coin, wpp=biased_coin_improper_pmf(c))\n", " return Pr[observer.c == _c]\n", "\n", "res = f_jit_autonorm(print_table=True)" ] }, { "cell_type": "markdown", "id": "86e3fdbe-77d9-4f4c-a5d5-53ea21b81c1c", "metadata": {}, "source": [ "### Output options\n", "\n", "`memo` can package the results in a variety of ways. By default, a\n", "`@memo` returns a JAX array." ] }, { "cell_type": "code", "execution_count": null, "id": "cc189e88", "metadata": {}, "outputs": [], "source": [ "f_jit()" ] }, { "cell_type": "markdown", "id": "1dab1f9c-1afd-4b55-ac71-6bdd9e262838", "metadata": {}, "source": [ "It is possible to additionally have `@memo` package the data in a 2D\n", "[`pandas`](https://pandas.pydata.org/docs/)\n", "[DataFrame](https://pandas.pydata.org/docs/reference/frame.html)" ] }, { "cell_type": "code", "execution_count": null, "id": "f928173d", "metadata": {}, "outputs": [], "source": [ "df = f_jit(return_pandas=True).aux.pandas\n", "print(\"DataFrame:\")\n", "print(df)\n", "print(\"\\nsliced:\")\n", "print(df.loc[df[\"_c\"] == \"HEADS\"])" ] }, { "cell_type": "markdown", "id": "aca8d55c-915f-4051-a103-8a78a4c8fbb3", "metadata": {}, "source": [ "And as an N-dimensional [`xarray`](https://docs.xarray.dev/en/stable/)\n", "with named axes and named indexes." ] }, { "cell_type": "code", "execution_count": null, "id": "08506514", "metadata": {}, "outputs": [], "source": [ "xa = f_jit(return_xarray=True).aux.xarray\n", "xa\n", "xa.loc[\"HEADS\"]" ] }, { "cell_type": "markdown", "id": "260ec22f-0c09-4fd9-b502-c9efc05136f2", "metadata": {}, "source": [ "These are not mutually exclusive." ] }, { "cell_type": "code", "execution_count": null, "id": "83288f85", "metadata": {}, "outputs": [], "source": [ "res = f_jit(print_table=True, return_pandas=True, return_xarray=True)\n", "# JAX array\n", "res.data\n", "# Pandas DataFrame\n", "res.aux.pandas\n", "# xarray\n", "res.aux.xarray" ] }, { "cell_type": "markdown", "id": "406a2ddb-29de-4905-9f6d-5f1eecf96c5d", "metadata": {}, "source": [ "> **Conversion**\n", ">\n", "> `pandas` and `xarray` are *much* slower than JAX, and conversion of\n", "> types introduces additional overhead. It’s advisable to only convert\n", "> your data as a terminal step.\n", "\n", "### Querying specific values of bound variables\n", "\n", "In the process of building `memo` models, it’s often useful to examine a\n", "particular realization of a variable rather than enumerating over all\n", "possible values. For instance, we could have this `@memo` return the\n", "probability of $\\text{HEADS}$ alone by specifying `Pr[observer.c == 1]`\n", "rather than `== _c`." ] }, { "cell_type": "code", "execution_count": null, "id": "d9709b40", "metadata": {}, "outputs": [], "source": [ "@memo\n", "def f_query():\n", " observer: given(c in Coin, wpp=biased_coin_pmf(c))\n", " return Pr[observer.c == 1]\n", "\n", "f_query(print_table=True)" ] }, { "cell_type": "markdown", "id": "84ebc814-ebda-4766-b585-f39d0685a4e4", "metadata": {}, "source": [ "## Building on the basics\n", "\n", "Now that we’ve built a simple `@memo`, let’s extend it by tossing the\n", "coin multiple times.\n", "\n", "Let’s imagine that your teacher hands you a coin and says that you’ll\n", "get extra credit if it comes up $\\text{HEADS}$ at least once when you\n", "toss it two times." ] }, { "cell_type": "code", "execution_count": null, "id": "46c293d2", "metadata": {}, "outputs": [], "source": [ "@memo\n", "def flip_twice_v1():\n", " student: given(flip1 in Coin, wpp=1)\n", " student: given(flip2 in Coin, wpp=1)\n", " return Pr[student.flip1 + student.flip2 >= 1]\n", "\n", "flip_twice_v1()" ] }, { "cell_type": "markdown", "id": "be850375-8436-4478-9b43-7adb68402109", "metadata": {}, "source": [ "But what if we want to flip the coin 10 or 1000 times? The approach of\n", "adding another `given` statement would be inefficient to scale.\n", "Fortunately, we can construct **product spaces** to handle this\n", "efficiently.\n", "\n", "### Product spaces\n", "\n", "Here we make a product space[1] of two flips of the coin.\n", "\n", "[1] For background on product spaces, see Michael Betancourt’s chapters,\n", "\n", "- [Product\n", " Spaces](https://betanalpha.github.io/assets/chapters_html/product_spaces.html)\n", "- [Probability Theory on Product\n", " Spaces](https://betanalpha.github.io/assets/case_studies/probability_on_product_spaces.html)" ] }, { "cell_type": "code", "execution_count": null, "id": "6702db82", "metadata": {}, "outputs": [], "source": [ "from memo import domain as product\n", "\n", "SampleSpaceTwoFlips = product(\n", " f1=len(Coin),\n", " f2=len(Coin),\n", ")" ] }, { "cell_type": "markdown", "id": "ee1ccb56-c819-422f-8a73-ad1961fd2a01", "metadata": {}, "source": [ "The result, `SampleSpaceTwoFlips`, is the cross product of the sample\n", "space of the first flip and that of the second flip:\n", "\n", "$$\n", "F_1 \\times F_2 = \\{ (T,T), (T,H), (H,T), (H,H) \\}\n", "$$" ] }, { "cell_type": "code", "execution_count": null, "id": "e5c11256", "metadata": {}, "outputs": [], "source": [ "for i in range(len(SampleSpaceTwoFlips)):\n", " print(SampleSpaceTwoFlips._tuple(i))" ] }, { "cell_type": "markdown", "id": "fd70a2c6-3c1b-4a7e-9d71-37c9c739f612", "metadata": {}, "source": [ "Evaluating `SampleSpaceTwoFlips` itself just returns the indices\n", "corresponding to these tuples." ] }, { "cell_type": "code", "execution_count": null, "id": "62ceb6a6", "metadata": {}, "outputs": [], "source": [ "SampleSpaceTwoFlips" ] }, { "cell_type": "markdown", "id": "3b4aca2b-9115-4889-8c90-364c7393a9c9", "metadata": {}, "source": [ "But you can access the underlying information in various ways:" ] }, { "cell_type": "code", "execution_count": null, "id": "3f742597", "metadata": {}, "outputs": [], "source": [ "SampleSpaceTwoFlips._tuple(2)\n", "SampleSpaceTwoFlips.f1(2)\n", "SampleSpaceTwoFlips.f2(2)" ] }, { "cell_type": "markdown", "id": "b400e703-09af-4bc9-8b22-74a479decd8a", "metadata": {}, "source": [ "Again, mind the sharp bits." ] }, { "cell_type": "code", "execution_count": null, "id": "379c8391", "metadata": {}, "outputs": [], "source": [ "SampleSpaceTwoFlips._tuple(100)" ] }, { "cell_type": "markdown", "id": "564f4d34-6db4-441b-bf69-ab3d941f2872", "metadata": {}, "source": [ "We can now enumerate over all the events that can occur (where an event\n", "is the sequences of outcomes that results from flipping the coin twice):\n", "`given(s in SampleSpaceTwoFlips, wpp=1)` (remember that\n", "SampleSpaceTwoFlips evaluates to a list of integers, `[0, 1, 2, 3]`). To\n", "help the code tidy, we can define a `@jax.jit` function to sum the\n", "tuple." ] }, { "cell_type": "code", "execution_count": null, "id": "6168ca8c", "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def sumflips(s):\n", " return SampleSpaceTwoFlips.f1(s) + SampleSpaceTwoFlips.f2(s)\n", "\n", "@memo\n", "def flip_twice():\n", " student: given(s in SampleSpaceTwoFlips, wpp=1)\n", " return Pr[sumflips(student.s) >= 1]\n", "\n", "flip_twice()" ] }, { "cell_type": "markdown", "id": "3e5b2faa-d25a-456d-9f0b-4a1200deec5c", "metadata": {}, "source": [ "Extending this to 10 flips is now straight forward. We simply define the\n", "sample space (now using dict comprehension to make\n", "`{\"f1\": 2, ..., \"f10\": 2}`) and dict unpacking (`**dict()`) to pass the\n", "contents to `product()` as keyword arguments. The result is\n", "`len(SampleSpace) == 1024`, which is the number of combinations that we\n", "expect ($2^10$).\n", "\n", "We also see that each tuple, which represents a sequence of 10 flips,\n", "has the expected size (`len(SampleSpace._tuple(0)) == 10`).\n", "\n", "Finally, we define a `@jax.jit` to sum this tuple. Here, we need to\n", "convert the tuple into a JAX array in order to sum it.\n", "\n", "Of course, your teacher would just be giving extra credit away if you\n", "had 10 flips to get a single head, let’s now say that you need between 4\n", "and 6 $\\text{HEADS}$ to win." ] }, { "cell_type": "code", "execution_count": null, "id": "b6e69060", "metadata": {}, "outputs": [], "source": [ "nflips = 10\n", "\n", "SampleSpace = product(**{f\"f{i}\": len(Coin) for i in range(1, nflips + 1)})\n", "\n", "@jax.jit\n", "def sumseq(s):\n", " return jnp.sum(jnp.array([SampleSpace._tuple(s)]))\n", "\n", "@memo\n", "def flip_n():\n", " student: given(s in SampleSpace, wpp=1)\n", " return Pr[sumseq(student.s) >= 4 and sumseq(student.s) <= 6]\n", "\n", "flip_n()" ] }, { "cell_type": "markdown", "id": "0545cd82-9d8a-47fc-b0e2-35fb8e309c47", "metadata": {}, "source": [ "Looks like your teacher is still quite generous!\n", "\n", "> **EXERCISE**\n", ">\n", "> To test your understanding, make sure you can calculate this value.\n", ">\n", "> > **ANSWER**\n", "> >\n", "> > Given on the webpage\n", "\n", "## Indexing\n", "\n", "Let’s imagine you have a deceptive teacher. After the second toss, she\n", "replaces the fair coin with a trick coin that only has a 10% chance of\n", "coming up $\\text{HEADS}$.\n", "\n", "We can calculate how much this will hurt your changes by specifying the\n", "probability mass on each flip, and then use those to get the probability\n", "mass on each sequence (which is what we need to pass as `wpp`).\n", "\n", "Let’s start by visualizing the distribution of $\\text{HEADS}$ in a\n", "sequences of 10 flips of a fair coin." ] }, { "cell_type": "code", "execution_count": null, "id": "35379f05", "metadata": {}, "outputs": [], "source": [ "from matplotlib import pyplot as plt\n", "\n", "nflips = 10\n", "\n", "SampleSpace = product(**{f\"f{i}\": len(Coin) for i in range(1, nflips + 1)})\n", "\n", "### repackage into a JAX array, which we'll use for indexing\n", "sample_space = jnp.array([SampleSpace._tuple(i) for i in range(len(SampleSpace))])\n", "\n", "fig, ax = plt.subplots()\n", "_ = ax.hist(sample_space.sum(axis=1).tolist(), color=\"blue\", alpha=0.3)\n", "ax.axvline(4, color=\"red\")\n", "ax.axvline(7, color=\"red\")\n", "ax.set_xticks((jnp.arange(nflips + 1) + 0.5).tolist())\n", "_ = ax.set_ylabel(\"Number of sequences\")\n", "_ = ax.set_xlabel(\"Number of HEADS\")\n", "_ = ax.set_xticklabels(range(nflips + 1))\n", "ax.set_xlim((0, nflips + 1))\n", "\n", "(nheads, nsequences) = jnp.unique(sample_space.sum(axis=1), return_counts=True)\n", "for (h, s) in zip(nheads.tolist(), nsequences.tolist()):\n", " print(f\"#HEADS: {h}, #sequences: {s}\")" ] }, { "cell_type": "markdown", "id": "256ac292-553a-45f0-9a11-fda0af89735e", "metadata": {}, "source": [ "Let’s assign the probability mass for each flip in order to calculate\n", "the probably of each sequence." ] }, { "cell_type": "code", "execution_count": null, "id": "4fe3d3ca", "metadata": {}, "outputs": [], "source": [ "### assign the probability mass for each flip\n", "flip_probs_biased = (\n", " jnp.full_like(\n", " sample_space, jnp.nan, dtype=float\n", " ) ### init a new array with nans (for safety)\n", " .at[jnp.where(sample_space == 1)] ### for every HEADS\n", " .set(0.1) ### assign it prob 0.1\n", " .at[jnp.where(sample_space == 0)] ### for every TAILS\n", " .set(0.9) ### assign it prob 0.9\n", " .at[:, :2] ### for the first two tosses\n", " .set(0.5) ### use a fair coin\n", ")\n", "\n", "### let's make sure we didn't mess us out indexing in an obvious way\n", "assert not jnp.any(jnp.isnan(flip_probs_biased))\n", "\n", "### the probability of a sequence is the product of the individual flips\n", "sequence_probs_biased = flip_probs_biased.prod(axis=1)\n", "\n", "### let's make sure the sample space is a simplex\n", "assert jnp.isclose(sequence_probs_biased.sum(), 1.0)\n", "\n", "flip_probs_biased" ] }, { "cell_type": "markdown", "id": "5cdb4642-fb96-45e9-8011-a231d67b6ce7", "metadata": {}, "source": [ "Let’s compare the distributions" ] }, { "cell_type": "code", "execution_count": null, "id": "f5b9f073", "metadata": {}, "outputs": [], "source": [ "### assign the probability masses for the fair coin\n", "flip_probs_unbiased = jnp.full_like(sample_space, 0.5, dtype=float)\n", "\n", "sequence_probs_unbiased = flip_probs_unbiased.prod(axis=1)\n", "assert jnp.isclose(sequence_probs_unbiased.sum(), 1.0)\n", "\n", "\n", "def calc_probs(sequence_probs):\n", " probs = []\n", " for nheads in range(nflips + 1):\n", " index = jnp.where(sample_space.sum(axis=1) == nheads)\n", " probs.append(sequence_probs[index].sum().item())\n", " return probs\n", "\n", "\n", "fig, ax = plt.subplots()\n", "ax.bar(\n", " range(nflips + 1),\n", " calc_probs(sequence_probs_unbiased),\n", " facecolor=\"none\",\n", " edgecolor=\"black\",\n", " linewidth=3.0,\n", " alpha=0.2,\n", " label=\"Unbiased Game\",\n", ")\n", "ax.bar(\n", " range(nflips + 1),\n", " calc_probs(sequence_probs_biased),\n", " facecolor=\"blue\",\n", " alpha=0.2,\n", " label=\"Deceptive Teacher\",\n", ")\n", "ax.axvline(3.5, color=\"red\")\n", "ax.axvline(6.5, color=\"red\")\n", "_ = ax.set_ylabel(\"$probability$\")\n", "_ = ax.set_xlabel(\"Number of HEADS\")\n", "_ = ax.set_xticks(range(nflips + 1))\n", "ax.legend()" ] }, { "cell_type": "markdown", "id": "f98d94be-70fd-4b99-a7dd-b95b3bdbbcd6", "metadata": {}, "source": [ "And now we can pass the probability mass we defined to `wpp`." ] }, { "cell_type": "code", "execution_count": null, "id": "36f2776f", "metadata": {}, "outputs": [], "source": [ "@jax.jit\n", "def probfn_biased(s):\n", " return sequence_probs_biased[s]\n", "\n", "@jax.jit\n", "def sumseq(s):\n", " return jnp.sum(jnp.array([*SampleSpace._tuple(s)]))\n", "\n", "@memo\n", "def flip_game():\n", " student: given(s in SampleSpace, wpp=probfn_biased(s))\n", " return Pr[sumseq(student.s) >= 4 and sumseq(student.s) <= 6]\n", "\n", "flip_game()" ] }, { "cell_type": "markdown", "id": "b55c8163-5b2a-4af3-b3ac-92d8cc3ccb78", "metadata": {}, "source": [ "How much did your teacher’s trickery affect your chances?\n", "\n", "## Exercise\n", "\n", "**A game of dice.** You have 3 die. The first dice has 4 sides, the\n", "second has 6 sides, and third has 8 sides. You roll your three die\n", "twice. The *d4* is fair. The *d6* is loaded such that there’s a 50%\n", "chance that it lands on 6, and a 10% chance that it lands on each other\n", "number. The *d8* is fair for the first roll, then it’s dropped and chips\n", "in such a way that it’s 3x more likely to land on an even number than an\n", "odd number (all evens are equally likely, and all odds are equally\n", "likely).\n", "\n", "1. Write a `@memo` that returns a JAX array with the probabilities of\n", " every possible combination of the die in this game (*i.e.* across\n", " all rolls).\n", "\n", "2. Write a `@memo` that returns the probability that the sum of the\n", " three die on roll 2 is greater than or equal to the sum of the three\n", " die on roll 1." ] }, { "cell_type": "markdown", "id": "caf15272-3959-47f3-942b-2d36583a0578", "metadata": { "raw_mimetype": "text/html", "vscode": { "languageId": "html" } }, "source": [ "
\n", "HINT\n", "If you're having trouble getting started, try making yourself a simpler version of the problem. \n", "E.g. start with \n", "(i) one unbiased 6-sided dice rolled once, then \n", "(ii) two 6-six sided die rolled once, \n", "(iii) two 6-sided die rolled twice, \n", "(iv) two biased 6-sided die rolled twice, \n", "etc.\n", "
\n", "You might find that it's easier to build the model up in this fashion. \n", "

\n", "This is generally a good strategy for model building — start with the simplest thing and extended it one tiny piece at a time, checking that each piece is working the way you expect it to as you go.\n", "
" ] }, { "cell_type": "markdown", "id": "49fe2e28-0565-4a00-a6b7-dab961044bb0", "metadata": {}, "source": [ "------------------------------------------------------------------------" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.1" } }, "nbformat": 4, "nbformat_minor": 5 }