Recursive Beliefs
Thinking about thinking with the 2/3rds game
Inspired by: Nagel (1995), “Unraveling in guessing games: An experimental study”.
Imagine that a newspaper like the New York Times runs a contest. Readers can write in a guess of an integer from 0-100. The editor computes the average of the readers’ guesses. The winner is the reader whose guess is closest to 2/3rds of the average.
Before continuing, write down what you would guess!
After you’ve written down your guess for (i) the New York Times readers, write down your guess if you were playing this game with (ii) a group of your friends, and (iii) a group of economists.
Are your guesses for (i), (ii), and (iii) the same? Different?
This problem evokes recursive reasoning over a Theory of Mind. One way to make a guess is to reason about what other people would reason about other people’s reasoning.
The Nash equilibrium for this game is for all readers to guess 0 or 1. However, in practice most people guess a number much higher than that. Why?
One reason a reader might choose a higher number is that she expects others to be less rational, and so they will on average pick a number greater than 2. It could also be that a reader thinks that others are rational enough to choose 0 or 1, but those people think that other people are less rational. Or maybe the reader was just not able to find the Nash equilibrium herself.
In this tutorial, we will model a “level-k” strategy for this game, and then fit the model’s predictions to real data collected by the New York Times.
Let’s bring in some of the standard imports.

import jax
import jax.numpy as jnp
from memo import memo
from matplotlib import pyplot as plt

Readers can choose an integer between 0 and 100, inclusive. Let’s define that sample space N.

N = jnp.arange(100 + 1)
N.shape

(101,)
A good practice for writing probabilistic models is to start overly simple and then build up. We’ll start with a generative model of what choices the greater population of readers makes. Recall how we modeled dice rolls in the Generative Models 1 tutorial. Let’s start by building a generative model of a uniformly random draw from this sample space (like rolling a 101-sided die).
@memo
def d101[_n: N]():
    observer: given(n in N, wpp=1)
    return Pr[observer.n == _n]

d101()
Array([0.00990099, 0.00990099, 0.00990099, 0.00990099, 0.00990099,
0.00990099, 0.00990099, 0.00990099, 0.00990099, 0.00990099,
0.00990099, 0.00990099, 0.00990099, 0.00990099, 0.00990099,
0.00990099, 0.00990099, 0.00990099, 0.00990099, 0.00990099,
0.00990099, 0.00990099, 0.00990099, 0.00990099, 0.00990099,
0.00990099, 0.00990099, 0.00990099, 0.00990099, 0.00990099,
0.00990099, 0.00990099, 0.00990099, 0.00990099, 0.00990099,
0.00990099, 0.00990099, 0.00990099, 0.00990099, 0.00990099,
0.00990099, 0.00990099, 0.00990099, 0.00990099, 0.00990099,
0.00990099, 0.00990099, 0.00990099, 0.00990099, 0.00990099,
0.00990099, 0.00990099, 0.00990099, 0.00990099, 0.00990099,
0.00990099, 0.00990099, 0.00990099, 0.00990099, 0.00990099,
0.00990099, 0.00990099, 0.00990099, 0.00990099, 0.00990099,
0.00990099, 0.00990099, 0.00990099, 0.00990099, 0.00990099,
0.00990099, 0.00990099, 0.00990099, 0.00990099, 0.00990099,
0.00990099, 0.00990099, 0.00990099, 0.00990099, 0.00990099,
0.00990099, 0.00990099, 0.00990099, 0.00990099, 0.00990099,
0.00990099, 0.00990099, 0.00990099, 0.00990099, 0.00990099,
0.00990099, 0.00990099, 0.00990099, 0.00990099, 0.00990099,
0.00990099, 0.00990099, 0.00990099, 0.00990099, 0.00990099,
0.00990099], dtype=float32)
Modeling agentic choice
Now we’re going to rewrite the program in the mentalistic grammar of memo. Since we think that people are actively choosing a number, let’s replace the given() statement with chooses(). These have exactly the same effect (they alias to the same internal function). And we’ll update the function name and the agent to reflect what we’re modeling, renaming the observer to population and the model to population_uniform_choice.
@memo
def population_uniform_choice[_n: N]():
    population: chooses(n in N, wpp=1)
    return Pr[population.n == _n]
### prove to ourselves that these models give the same output
jnp.array_equal(population_uniform_choice(), d101()).item()
True
Parametric priors
Let’s add a bit of complexity. Say that we think that people’s choice of numbers is Gaussian rather than uniform. We can add a prior by passing a probability function as wpp (remember that the function needs to be wrapped with jax.jit).
from jax.scipy.stats.norm import pdf as normpdf
normpdf = jax.jit(normpdf)

@memo
def population_gaussian_choice[_n: N]():
    population: chooses(
        n in N,
        wpp=normpdf(n, 50, 3)
    )
    return Pr[population.n == _n]
res = population_gaussian_choice()
expectation = jnp.dot(N, res)

fig, ax = plt.subplots()
ax.bar(N, res)
ax.set_xticks(range(0, 100 + 1, 10))
ax.axvline(expectation, color="red", label=r"$\mathbf{E}[N] = " + f"{expectation:.2f}$")
_ = ax.set_xlabel("$n$")
_ = ax.set_ylabel("$P(n)$")
ax.legend()
print(f"Expectation: {expectation}")
Expectation: 50.0
Try playing with the parameters of the prior. We know that Gaussian distributions are symmetric. When the distribution is centered on 50, the expectation of the choices is 50 (that makes sense). What happens if you use \mathcal{N}(\mu=5, \sigma=3)? What about \mu=0?
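To get a feel for this without editing the @memo, here is a minimal sketch (added here, not part of the original tutorial) that renormalizes the Gaussian weights directly with jnp, which is what memo does with wpp under the hood:

for mu in (50, 5, 0):
    w = normpdf(N, mu, 3)   ### unnormalized Gaussian weights over 0..100
    p = w / w.sum()         ### renormalize over the support, as memo does with wpp
    print(f"mu = {mu:>2} -> E[N] = {jnp.dot(N, p):.2f}")

Because the support is truncated at 0, the expectation under \mu=0 stays slightly above 0.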
In this example, we are passing a PDF (probability density function) as wpp. Memo will treat this as a PMF (probability mass function), normalizing the probability measure over the support.
For instance, \sum_{x \in \{0,2,4\}} \mathcal{N}(x \mid 0, 1) is, of course, not going to form a proper probability measure (it does not sum to one):
X = jnp.array([0, 2, 4])

normpdf(X, 0, 1)
print(f"Sum: {normpdf(X, 0, 1).sum()}")
Array([3.9894229e-01, 5.3990960e-02, 1.3383021e-04], dtype=float32)
Sum: 0.4530670940876007
But when passed as wpp, memo converts these values into a PMF.
@memo
def example[_x: X]():
    a: chooses(
        x in X,
        wpp=normpdf(x, 0, 1)
    )
    return Pr[a.x == _x]

example()
print(f"Sum: {example().sum()}")
Array([8.8053691e-01, 1.1916769e-01, 2.9538717e-04], dtype=float32)
Sum: 1.0
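As a sanity check (a small sketch, not from the original), we can reproduce memo’s normalization by hand:

w = normpdf(X, 0, 1)
print(w / w.sum())   ### should match the output of example() above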
Query variables
In the definition of a @memo, you can include “query variables” in square brackets, e.g. def f[_a: A, _b: B, _c: C]():. memo will enumerate over the values that these can take, returning an array with the same number of axes.
A = (jnp.arange(2) + 1) * 100
B = (jnp.arange(3) + 1) * 10
C = jnp.arange(4) + 1

@memo
def f[_a: A, _b: B, _c: C]():  ### Specify the query variables in the function definition
    return _a + _b + _c  ### Return the sum of the query variables

f()
print(f"Shape of output array: {f().shape}")
Array([[[111, 112, 113, 114],
[121, 122, 123, 124],
[131, 132, 133, 134]],
[[211, 212, 213, 214],
[221, 222, 223, 224],
[231, 232, 233, 234]]], dtype=int32)
Shape of output array: (2, 3, 4)
The @memo f() above will enumerate over \{ (\_a, \_b, \_c) : \_a \in A, \_b \in B, \_c \in C \}. This causes @memo to return an array with 3 axes, where each element is the value of the @memo evaluated at the query variables jointly: f(_a, _b, _c). Thus, f() outputs an array with 3 axes (corresponding to _a, _b, and _c), where dimension 1 has a size of len(A) (the number of elements in A, which is 2), dimension 2 has a size of len(B), and dimension 3 has a size of len(C).
This is easy to see with print_table=True:

_ = f(print_table=True)
+-------+-------+-------+------+
| _a: A | _b: B | _c: C | f |
+-------+-------+-------+------+
| 100 | 10 | 1 | 111 |
| 100 | 10 | 2 | 112 |
| 100 | 10 | 3 | 113 |
| 100 | 10 | 4 | 114 |
| 100 | 20 | 1 | 121 |
| 100 | 20 | 2 | 122 |
| 100 | 20 | 3 | 123 |
| 100 | 20 | 4 | 124 |
| 100 | 30 | 1 | 131 |
| 100 | 30 | 2 | 132 |
| 100 | 30 | 3 | 133 |
| 100 | 30 | 4 | 134 |
| 200 | 10 | 1 | 211 |
| 200 | 10 | 2 | 212 |
| 200 | 10 | 3 | 213 |
| 200 | 10 | 4 | 214 |
| 200 | 20 | 1 | 221 |
| 200 | 20 | 2 | 222 |
| 200 | 20 | 3 | 223 |
| 200 | 20 | 4 | 224 |
| 200 | 30 | 1 | 231 |
| 200 | 30 | 2 | 232 |
| 200 | 30 | 3 | 233 |
| 200 | 30 | 4 | 234 |
+-------+-------+-------+------+
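Another way to confirm the axis ordering (a small sketch added here) is to index the returned array by position; each axis follows the order of the query variables in the function definition:

out = f()
print(out[0, 1, 2])   ### _a=100 (index 0), _b=20 (index 1), _c=3 (index 2) -> 123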
Details and examples
In the process of building up a model, you might start with something like this, which doesn’t have any query variables.
X = jnp.array([40, 45, 55, 60])

@memo
def example1():
    agent: chooses(x in X, wpp=1)
    return Pr[agent.x]

res = example1()
print(f"Shape: {res.shape}")
res
Shape: ()
Array(50., dtype=float32)
What’s being returned by example1()?
Since there are no query variables, the function will always return the same value. In this case, it’s the expectation of the agent’s choice. To see the probability of each choice x, enumerate over \{\_x \in X\} and return the probability for each \_x.
X = jnp.array([40, 45, 55, 60])

@memo
def example2[_x: X]():  ### 1
    agent: chooses(x in X, wpp=1)
    return Pr[agent.x == _x]  ### 2

res = example2()
print(f"Mean: {res.mean()}")
print(f"Expectation: {jnp.dot(res, X)}")
print(f"Shape: {res.shape}")
res
1. Add [_x: X] to the function definition
2. Add == _x to the return statement
Mean: 0.25
Expectation: 50.0
Shape: (4,)
Array([0.25, 0.25, 0.25, 0.25], dtype=float32)
Adding [_x: X] to the function definition makes memo add a new dimension to the output data. Adding == _x to the return statement makes the function return a value (in this case, the probability that the agent chooses _x) for each value of _x.
Notice that def example1(): returns Array(50.), which has the shape () (i.e. no dimensions, it is a single number). However, def example2[_x: X](): returns Array([0.25, 0.25, 0.25, 0.25]), which has the shape (4,) (i.e. an array with one axis and 4 elements).
In this way, you can evaluate a @memo for multiple query variables jointly:
from enum import IntEnum

A = jnp.arange(3) * -1

class B(IntEnum):
    VAL1 = 0
    VAL2 = 1

C = jnp.arange(4) + 100

@memo
def example3[_a: A, _b: B, _c: C]():
    agent: chooses(a in A, wpp=1)
    agent: chooses(b in B, wpp=1)
    agent: chooses(c in C, wpp=1)
    return Pr[agent.a == _a, agent.b == _b, agent.c == _c]

res = example3(print_table=True)
print(f"Shape: {res.shape}")
res
+-------+-------+-------+---------------------+
| _a: A | _b: B | _c: C | example3 |
+-------+-------+-------+---------------------+
| 0 | VAL1 | 100 | 0.0416666679084301 |
| 0 | VAL1 | 101 | 0.0416666679084301 |
| 0 | VAL1 | 102 | 0.0416666679084301 |
| 0 | VAL1 | 103 | 0.0416666679084301 |
| 0 | VAL2 | 100 | 0.0416666679084301 |
| 0 | VAL2 | 101 | 0.0416666679084301 |
| 0 | VAL2 | 102 | 0.0416666679084301 |
| 0 | VAL2 | 103 | 0.0416666679084301 |
| -1 | VAL1 | 100 | 0.0416666679084301 |
| -1 | VAL1 | 101 | 0.0416666679084301 |
| -1 | VAL1 | 102 | 0.0416666679084301 |
| -1 | VAL1 | 103 | 0.0416666679084301 |
| -1 | VAL2 | 100 | 0.0416666679084301 |
| -1 | VAL2 | 101 | 0.0416666679084301 |
| -1 | VAL2 | 102 | 0.0416666679084301 |
| -1 | VAL2 | 103 | 0.0416666679084301 |
| -2 | VAL1 | 100 | 0.0416666679084301 |
| -2 | VAL1 | 101 | 0.0416666679084301 |
| -2 | VAL1 | 102 | 0.0416666679084301 |
| -2 | VAL1 | 103 | 0.0416666679084301 |
| -2 | VAL2 | 100 | 0.0416666679084301 |
| -2 | VAL2 | 101 | 0.0416666679084301 |
| -2 | VAL2 | 102 | 0.0416666679084301 |
| -2 | VAL2 | 103 | 0.0416666679084301 |
+-------+-------+-------+---------------------+
Shape: (3, 2, 4)
Array([[[0.04166667, 0.04166667, 0.04166667, 0.04166667],
[0.04166667, 0.04166667, 0.04166667, 0.04166667]],
[[0.04166667, 0.04166667, 0.04166667, 0.04166667],
[0.04166667, 0.04166667, 0.04166667, 0.04166667]],
[[0.04166667, 0.04166667, 0.04166667, 0.04166667],
[0.04166667, 0.04166667, 0.04166667, 0.04166667]]], dtype=float32)
Indexing output
Remember that you can convert these into a pandas DataFrame or an xarray for easy indexing, e.g. xa.loc[-2, "VAL2", 103]:
res = example3(return_xarray=True)
data = res.data
xa = res.aux.xarray
print(f"JAX array: {data[2, 1, 3]}")
print(f"xarray: {xa.loc[-2, 'VAL2', 103].item()}")
JAX array: 0.0416666679084301
xarray: 0.0416666679084301
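If you prefer pandas, the xarray converts directly; this is a sketch using standard xarray/pandas methods rather than anything memo-specific:

df = xa.to_series().reset_index(name="prob")   ### long-format table: one row per query combination
print(df.head())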
Empty axes
If you include a query variable in the function definition, but return a value that does not depend on the query variable, memo will add an empty axis to the output array:
@memo
def example4[_a: A, _b: B, _c: C]():
    agent: chooses(a in A, wpp=1)
    agent: chooses(b in B, wpp=1)
    agent: chooses(c in C, wpp=1)
    # return Pr[agent.a == _a, agent.b == _b, agent.c == _c]
    return Pr[agent.b == _b]

res = example4()
print(f"Shape: {res.shape}")
res
Shape: (1, 2, 1)
Array([[[0.5],
[0.5]]], dtype=float32)
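The size-1 axes can be dropped with a standard jnp call if you don’t need them (a small sketch, not part of the original):

print(jnp.squeeze(res).shape)   ### (2,) -- only the _b axis carries information here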
Identity of query variables
Note that multiple query variables can derive from the same sample space, as in \{\_x \in Z\} and \{\_y \in Z\}.
C = jnp.arange(-10, 10 + 1)
Z = jnp.array([-1, 0, 1])

@memo
def example5[_x: Z, _y: Z]():
    agent1: chooses(c in C, wpp=1)
    agent2: chooses(c in C, wpp=1)
    return Pr[agent1.c < _x, agent2.c > _y]

res = example5()
print(f"Shape: {res.shape}")
res
Shape: (3, 3)
Array([[0.22448967, 0.20408155, 0.18367343],
[0.24943294, 0.22675724, 0.20408155],
[0.2743762 , 0.24943294, 0.22448967]], dtype=float32)
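As a quick check (a sketch added here), the top-left entry can be reproduced by hand: with c uniform on C, it is just P(c < -1) times P(c > -1), since the two agents choose independently.

p1 = (C < -1).mean()   ### fraction of values in C below -1 (9/21)
p2 = (C > -1).mean()   ### fraction of values in C above -1 (11/21)
print(p1 * p2)         ### ~0.2245, matching res[0, 0]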
Using a @memo as a submodel
Returning to our example, we have a simple generative model of the population’s choice. Let’s now model a reader thinking about the population. We can do this with the memo syntax thinks[]. In the model below, we build a new @memo, reader_thinks(), with a new mental frame, reader.
@memo
def population_gaussian_choice_prob[_n: N]():
    population: chooses(n in N, wpp=normpdf(n, 50, 3))
    return Pr[population.n == _n]

@memo
def reader_thinks():
    reader: thinks[
        population: chooses(
            n in N,
            wpp=population_gaussian_choice_prob[n]()
        )
    ]
    return reader[ E[population.n] ]

reader_thinks()
Array(50., dtype=float32)
The reader_thinks() model simulates what number the reader thinks the population will choose, generating a probability for each of the 101 numbers. Then, the @memo returns the reader’s belief about the expected value of the population’s choice (return reader[ E[population.n] ]).
Pay attention to the grammar here: reader[ ... ] means that we are querying information in the reader’s mind. Specifically, the reader’s expectation (E[ ... ]) of the population’s choice n. Square brackets function as a way of entering a mental frame.1
We are now using E (for Expectation) rather than Pr (for Probability). This is just syntactic sugar. Under the hood, Pr is an alias of E.
memo enumerates over discrete random variables by representing them as a sum of Bernoulli random variables. Thus, the probability of an event is given by the expectation of its Bernoulli random variables.
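To convince yourself of the equivalence, here is a minimal sketch (a hypothetical variant, not part of the original) that swaps Pr for E in the earlier uniform model and compares the outputs:

@memo
def population_uniform_choice_E[_n: N]():
    population: chooses(n in N, wpp=1)
    return E[population.n == _n]   ### E instead of Pr

jnp.array_equal(population_uniform_choice_E(), population_uniform_choice()).item()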
Notice how the @memo population_gaussian_choice_prob() is serving as a submodel in the @memo reader_thinks(). We’ll return to this idea shortly. For the moment, since there’s nothing complicated happening in population_gaussian_choice_prob(), let’s simplify the model into a single @memo:
@memo
def reader_thinks():
    reader: thinks[
        population: chooses(n in N, wpp=normpdf(n, 50, 3))
    ]
    return reader[ E[population.n] ]  ### query what the reader thinks about the population

reader_thinks()
Array(50., dtype=float32)
And just to make the next steps as clear as possible, let’s return to a uniform prior over the reader’s belief about the population’s choice:
@memo
def reader_thinks():
    reader: thinks[
        population: chooses(n in N, wpp=1)  ### back to a uniform prior
    ]
    return reader[ E[population.n] ]  ### query what the reader thinks about the population

reader_thinks()
Array(50., dtype=float32)
Encapsulation in mental frames
Note that the population frame only exists in the reader’s mind, i.e. internal to the reader frame. Trying to access the population frame outside of the reader frame raises an error.
@memo
def psychic_reader():
    reader: thinks[
        population: chooses(n in N, wpp=1)
    ]
    # return reader[ E[population.n] ]  ### query what the reader thinks about the population
    #########
    ### attempt to access information about the
    ### population outside of the reader's mind:
    #########
    return E[population.n]

psychic_reader()
memo.core.MemoError: Unknown choice population.n
file: "3754222123.py", line 11, in @memo psychic_reader
return E[population.n]
^
hint: Did you perhaps misspell n? population is not yet aware of any
choice called n. Or, did you forget to call
population.chooses(n ...) or population.knows(n) earlier in
the memo?
ctxt: This error was encountered in the frame of population. In
that frame, population is currently modeling the following 0
choices: .
info: You are using memo 1.1.0, JAX 0.5.2, Python 3.13.2 on Darwin.
memo forces us to specify that what we are modeling is what the reader thinks about the population, and that this is different from the actual population (which we are not currently modeling).
Choice based on mental simulation
Let’s model the reader actually playing the game (in the simplest way possible). Rather than the reader estimating the average of the population’s choice, we’ll have the reader estimate 2/3rds of the average.
@memo
def reader_thinks():
    reader: thinks[
        population: chooses(n in N, wpp=1)
    ]
    return reader[ (2/3)*E[population.n] ]

reader_thinks()
Array(33.333336, dtype=float32)
Now that the reader is estimating 2/3rds of the population’s average choice, we can model the reader’s choice. We’ll start by modeling the reader as perfectly rational. I.e. the reader will always pick the value that’s closest to 2/3rds of what the reader thinks the population’s average choice will be.
To accomplish this perfectly rational behavior, replace wpp in chooses(...) with to_maximize. With the to_maximize keyword, chooses() performs an argmax, assigning a probability of 1 to the choice with the greatest to_maximize value.
To have the reader pick the choice closest to 2/3rds the average, we find the absolute difference between 2/3rds the average and every possible choice n: abs( (2/3)*E[population.n] - n ), and pass the negative absolute error as to_maximize=-abs(...), since we want the reader to pick the n with the smallest absolute difference.
@memo
def reader_choice[_n: N]():
    reader: thinks[
        population: chooses(n in N, wpp=1)
    ]
    reader: chooses(n in N, to_maximize=-abs( (2/3)*E[population.n] - n ))
    return Pr[reader.n == _n]

jnp.where(reader_choice())
(Array([33], dtype=int32),)
The model returns the probability of the reader making each choice. Since we’re using to_maximize, the reader will always choose the n with the smallest absolute error, so evaluating reader_choice() produces an array of len(N) elements, with all of the probability mass on the n closest to (2/3)*E[population.n].
jnp.where(...) returns the indices at which there are non-zero values. Is this the index you expect?
Under a uniform prior, the expected value of the population choice is 50. 2/3rds of that is 33.3333. The closest n that the agent can choose is 33, which is index 33 in N. It seems this model does what we expect it to, but it is still far from being a model of what we think readers are actually doing.
Recursive reasoning
A critical missing component is the reader thinking about a population of other readers who are performing the same type of reasoning, trying to guess 2/3rds of the average. How can we model this recursive reasoning?
We previously saw how a @memo can be used as a probability function in the chooses() statement of another @memo. Recall:
@memo
def population_gaussian_choice_prob[_n: N]():
    population: chooses(n in N, wpp=normpdf(n, 50, 3))
    return Pr[population.n == _n]

@memo
def reader_thinks():
    reader: thinks[
        population: chooses(n in N, wpp=population_gaussian_choice_prob[n]())
    ]
    return reader[ E[population.n] ]
We can do something similar now. Since we are assuming (as a first approximation) that the reader thinks other readers engage in the same type of reasoning when playing this game, we can use the same @memo as a probability model for itself.
@memo
def reader_choice_k[_n: N](k):  ### 1
    reader: thinks[
        population: chooses(
            n in N,
            wpp=reader_choice_k[n](k-1) if k > 0 else 1  ### 2
        )
    ]
    reader: chooses(n in N, to_maximize=-abs((2/3)*E[population.n] - n))
    return Pr[reader.n == _n]

jnp.where(reader_choice_k(1))  ### 3
1. Pass a k level to the @memo (reader_choice_k) as an argument
2. Evaluate reader_choice_k(k-1) to get a probability measure for each n
3. Evaluate the model for a level-k reader with k=1
(Array([22], dtype=int32),)
The reader_choice_k model makes two changes to the reader_choice model. First, the function definition is modified: an integer k is passed to the model as an argument. Notice that this is different from a query variable (e.g. [_n: N]). Whereas _n is bound to the outer frame (i.e. the frame of the reader_choice_k function), the k argument is unbound and can be accessed anywhere in the @memo. Second, the probability that (in the reader’s mind) the population will choose n is given by reader_choice_k[n](k-1) if k > 0 else 1. Let’s break this down.
reader_choice_k is defined with [_n: N], which specifies that it should enumerate over \_n \in N. When reader_choice_k is called in population: chooses(n in N, wpp=...), the n being considered is passed to reader_choice_k as _n: reader_choice_k[n]. This is how the probability for the correct value of n is passed as wpp. The function also needs to be called, using (...). If there are no arguments, then nothing goes into the parentheses (as was the case with wpp=population_gaussian_choice_prob[n]() above). In this case, we need to pass a value for k. But what value?
We pass the submodel a k level one less than the k level of the model that calls it. Why? Here are two ways to think about it:
- If k does not decrease, then the recursion never terminates (you can try this if you want; replace k-1 with k and you’ll receive a RecursionError: maximum recursion depth exceeded).
- To model a certain level of strategic sophistication, your cognitive system must have at least that much complexity. To represent level-k reasoning, a cognitive system must be at least level k+1.
Finally, there is a boundary condition that says “if k is 0, then treat the population as making a uniform choice”:

reader_choice_k[n](k-1) if k > 0 else 1
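Before running the memo version, here is a plain-Python sketch of the same unraveling logic (a hypothetical helper, not memo), assuming the reader always best-responds exactly and that the level-0 belief is a uniform population:

def level_k_guess(k):
    ### boundary condition: a level-0 reader assumes a uniform population (mean 50)
    belief = 50.0 if k == 0 else level_k_guess(k - 1)
    ### best response: the integer closest to 2/3rds of the believed average
    return round((2/3) * belief)

print([level_k_guess(k) for k in range(6)])   ### should unravel toward the equilibrium: 33, 22, 15, 10, 7, 5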
First run this @memo as is. @memo(debug_trace=True) will produce an execution trace that shows how memo recurses through the k values.
Then try changing k-1 to k. What happens and why?
@memo(debug_trace=True)
def reader_choice_k_example[_n: N](k):
    reader: thinks[
        population: chooses(
            n in N,
            ### change from k-1 to k
            wpp=reader_choice_k_example[n](k-1) if k > 0 else 1
        )
    ]
    reader: chooses(n in N, to_maximize=-abs((2/3)*E[population.n] - n))
    return Pr[reader.n == _n]

reader_choice_k_example(3)
--> reader_choice_k_example(3)
--> reader_choice_k_example(2)
--> reader_choice_k_example(1)
--> reader_choice_k_example(0)
<-- reader_choice_k_example(0) has shape (101,)
time = 0.041246 sec
<-- reader_choice_k_example(1) has shape (101,)
time = 0.041315 sec
<-- reader_choice_k_example(2) has shape (101,)
time = 0.041352 sec
<-- reader_choice_k_example(3) has shape (101,)
time = 0.041423 sec
Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)
What happens when we pass this model different k values?
for k_ in range(12):
    res = reader_choice_k(k=k_)
    print(f"k = {k_:>2} -> guess = {jnp.where(res)[0].item():>2}")
k = 0 -> guess = 33
k = 1 -> guess = 22
k = 2 -> guess = 15
k = 3 -> guess = 10
k = 4 -> guess = 7
k = 5 -> guess = 5
k = 6 -> guess = 3
k = 7 -> guess = 2
k = 8 -> guess = 1
k = 9 -> guess = 1
k = 10 -> guess = 1
k = 11 -> guess = 1
At k=0, the reader thinks that the population is making a uniform choice. Thus, the expectation is 50, and the reader guesses 33. At k=1, the reader thinks that everyone else is reasoning like the k=0 model. Thus the reader guesses 2/3rds of 33. Etc.
Why does the guess never reach zero? Remember that the reader can only choose integers. When the reader thinks that everyone else will guess 1, the model finds the integer closest to (2/3)*E[1], which is 2/3. And 0.6667 is closer to 1 than to 0, so according to this model, a perfectly rational agent will always guess 1.
Try changing (2/3) to something lower, like 0.49. What happens then? How does that change the equilibrium? Does that affect the number of steps required for the model to converge to the equilibrium?
Approximately rational choice
Softmax
A softmax choice rule assigns option i a probability that increases with its value z_i, scaled by an inverse temperature parameter \beta:

P(i) = \frac{e^{\beta z_i}}{\sum_{j} e^{\beta z_j} }
In practice we’re most often concerned with the relative probabilities of the options, so it is sufficient to use

P(i) \propto \exp(\beta z_i)

since the denominator cancels out (e.g. in P(i=1)/P(i=2)).
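As a quick numeric check (a sketch added here), normalizing \exp(\beta z_i) by hand gives a proper distribution:

z = jnp.array([1.0, 0.4, 0.0])
beta = 2.0
p = jnp.exp(beta * z) / jnp.exp(beta * z).sum()
print(p, p.sum())   ### probabilities favor the highest-value option and sum to 1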
Interactive exploration of the softmax function
Plotly code
def softmax(x, beta=1.0):
    """Compute softmax values with inverse temperature parameter beta"""
    import numpy as np
    x = np.array(x) * beta
    exp_x = np.exp(x - np.max(x))
    return exp_x / exp_x.sum()
def make_plotly_softmax_fig(input_values, labels):
    import numpy as np
    import plotly.graph_objects as go

    # Create figure
    fig = go.Figure()

    # Add traces for different beta values
    beta_values = np.arange(0.1, 5.1, 0.1)
    for beta in beta_values:
        probs = softmax(input_values, beta)

        # Add bars for input values (scaled for visibility)
        _ = fig.add_trace(
            go.Bar(
                visible=False,
                x=labels,
                y=input_values,  # Scale down inputs for better visualization
                name="Input Values",
                marker_color='rgba(76, 175, 80, 0.7)',
                showlegend=True,
                text=[f"Input: {v:.1f}" for v in input_values],
                textposition='outside'
            )
        )

        # Add bars for softmax probabilities
        _ = fig.add_trace(
            go.Bar(
                visible=False,
                x=labels,
                y=probs,
                name="Probability of Choice",
                marker_color='rgba(33, 150, 243, 0.7)',
                showlegend=True,
                text=[f'{p:.3f}' for p in probs],
                textposition='outside'
            )
        )

    # Make first pair of traces visible
    fig.data[0].visible = True
    fig.data[1].visible = True

    # Create and add slider
    steps = []
    for i in range(0, len(fig.data), 2):  # Step by 2 since we have pairs of traces
        beta = beta_values[i//2]
        step = dict(
            method="update",
            args=[{"visible": [False] * len(fig.data)},
                  {"title": f"Softmax Parameter (β = {beta:.1f})"}],
            label=f"{beta:.1f}",
        )
        step["args"][0]["visible"][i] = True      # Show input bars
        step["args"][0]["visible"][i+1] = True    # Show probability bars
        steps.append(step)

    sliders = [dict(
        active=0,
        # currentvalue={"prefix": "β = "},
        currentvalue={
            "prefix": "beta = ",
            # "font": {"size": 16},
        },
        pad={"t": 50},
        steps=steps
    )]

    # Update layout
    _ = fig.update_layout(
        title="Softmax Function Demonstration",
        sliders=sliders,
        barmode='group',
        yaxis=dict(
            range=[-1, 1.2],  # Set y-axis range to accommodate text labels
            title="Value / Probability"
        ),
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99
        )
    )

    fig.show()
In this plot, the options and corresponding values are A=1, B=0.4, and C=0 (shown in green). The probability of choosing A, B, or C is shown in blue. As \beta gets larger, it becomes increasingly probable that the option with the highest value (in this case, A) will be selected over the other options. In the limit of \beta approaching infinity, the softmax becomes the argmax, meaning that the option with the highest value will always be chosen, no matter how small its margin over the next best choice.
input_values_ = jnp.array([1.0, 0.4, 0.0])
labels_ = ['Option A', 'Option B', 'Option C']
make_plotly_softmax_fig(input_values_, labels_)
The softmax function works with any real-valued input. The input values can be greater than 1, less than zero, etc.
Plotly code
input_values_ = -1 * jnp.array([1.0, 0.4, 0.0])
labels_ = ['Option A', 'Option B', 'Option C']
make_plotly_softmax_fig(input_values_, labels_)
To model an approximately rational reader, we swap to_maximize for a softmax choice: wpp=exp(beta * -abs(...)) makes the reader choose n with probability proportional to the exponentiated negative error, so larger beta approaches the argmax reader above.

@memo
def reader_choice_k_soft[_n: N](k, beta=1):
    reader: thinks[
        population: chooses(
            n in N,
            wpp=reader_choice_k_soft[n](k-1, beta=beta) if k > 0 else 1
        )
    ]
    reader: chooses(n in N, wpp=exp(beta * -abs((2/3)*E[population.n] - n)))
    return Pr[reader.n == _n]
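A quick check of the new model (a sketch, not part of the original): even at k=0 the softmax reader should still peak at 33, but it now spreads probability over nearby guesses, and the distribution sums to 1.

p0 = reader_choice_k_soft(0, beta=1.0)
print(jnp.argmax(p0).item(), p0.sum())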
Exploring the softmax parameter in the game
For k=5 and \beta = 1.0,
k_ = 5
beta_ = 1.0
fig, ax = plt.subplots()
ax.bar(N, reader_choice_k_soft(k_, beta=beta_), alpha=0.2, color='r')
ax.plot(N, reader_choice_k_soft(k_, beta=beta_), color='r', alpha=0.8, label=f"k={k_}")
ax.set_ylim((0, 0.5))

jnp.where(reader_choice_k(k_))
jnp.argmax(reader_choice_k_soft(k_, beta=beta_))
jnp.dot(reader_choice_k_soft(k_, beta=beta_), N)
Array(4, dtype=int32)
Array(4.416822, dtype=float32)
For k \in \{ 0, 1, 2, 5, 7 \} and \beta = 1.0,
fig, ax = plt.subplots()
k_vals = [0, 1, 2, 5, 7]
beta_ = 1.0
for k_ in k_vals:
    ax.plot(N, reader_choice_k_soft(k_, beta=beta_), alpha=0.8, label=f"k={k_}")
ax.set_ylim((0, 0.5))
ax.legend()
For k = 5 and \beta \in \{ 0.1, 1.0, 3.0 \},
fig, ax = plt.subplots()
k_ = 5
beta_values = [0.1, 1.0, 3.0]
for beta_ in beta_values:
    ax.plot(N, reader_choice_k_soft(k_, beta=beta_), alpha=0.8, label=f"beta={beta_}")
ax.set_ylim((0, 0.5))
ax.legend()
Interactive plot of the effect of \beta on P(N{=}n; k{=}5)
Plotly code
def make_reader_choice_plot(N, k, beta_values):
    import numpy as np
    import plotly.graph_objects as go

    # Create figure
    fig = go.Figure()

    # Add traces for different beta values
    for beta in beta_values:
        probs = reader_choice_k_soft(k, beta=beta)

        # Add bars for probability distribution
        _ = fig.add_trace(
            go.Bar(
                visible=False,
                x=N,
                y=probs,
                name=f'k={k}, β={beta}',
                showlegend=True,
                text=[f'{p:.3f}' for p in probs],
                textposition='outside'
            )
        )

    # Make first trace visible
    fig.data[0].visible = True

    # Create slider steps
    steps = []
    for beta_idx, beta in enumerate(beta_values):
        step = dict(
            method="update",
            args=[{"visible": [False] * len(fig.data)},
                  {"title": f"Reader Choice Distribution (k={k}, β={beta})"}],
            label=f"{beta:.1f}"
        )
        step["args"][0]["visible"][beta_idx] = True
        steps.append(step)

    sliders = [dict(
        active=0,
        currentvalue={"prefix": "β = "},
        pad={"t": 50},
        steps=steps
    )]

    # Update layout
    _ = fig.update_layout(
        title="Reader Choice Distribution",
        xaxis_title="N",
        yaxis_title="Probability",
        width=800,
        height=500,  # Reduced height since we only have one slider
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99
        ),
        sliders=sliders
    )

    # Set axis ranges
    fig.update_xaxes(range=[-1, 101])
    fig.update_yaxes(range=[0, 0.15])  # Adjusted based on typical probability ranges

    return fig
k = 5
beta_values = [0.1, 0.5, 1, 2, 3, 5]
fig = make_reader_choice_plot(N, k, beta_values)
fig.show()
%reset -f
import sys
import platform
import importlib.metadata
print("Python:", sys.version)
print("Platform:", platform.system(), platform.release())
print("Processor:", platform.processor())
print("Machine:", platform.machine())
print("\nPackages:")
for name, version in sorted(
    ((dist.metadata["Name"], dist.version) for dist in importlib.metadata.distributions()),
    key=lambda x: x[0].lower()  # Sort case-insensitively
):
    print(f"{name}=={version}")
Python: 3.13.2 (main, Feb 5 2025, 18:58:04) [Clang 19.1.6 ]
Platform: Darwin 23.6.0
Processor: arm
Machine: arm64
Packages:
annotated-types==0.7.0
anyio==4.8.0
appnope==0.1.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==3.0.0
async-lru==2.0.4
attrs==25.1.0
babel==2.17.0
beautifulsoup4==4.13.3
bleach==6.2.0
certifi==2025.1.31
cffi==1.17.1
cfgv==3.4.0
charset-normalizer==3.4.1
click==8.1.8
comm==0.2.2
contourpy==1.3.1
cycler==0.12.1
debugpy==1.8.13
decorator==5.2.1
defusedxml==0.7.1
distlib==0.3.9
executing==2.2.0
fastjsonschema==2.21.1
filelock==3.17.0
fonttools==4.56.0
fqdn==1.5.1
h11==0.14.0
httpcore==1.0.7
httpx==0.28.1
identify==2.6.8
idna==3.10
importlib_metadata==8.6.1
ipykernel==6.29.5
ipython==9.0.1
ipython_pygments_lexers==1.1.1
ipywidgets==8.1.5
isoduration==20.11.0
jax==0.5.2
jaxlib==0.5.1
jedi==0.19.2
Jinja2==3.1.6
joblib==1.4.2
json5==0.10.0
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
jupyter-cache==1.0.1
jupyter-events==0.12.0
jupyter-lsp==2.2.5
jupyter_client==8.6.3
jupyter_core==5.7.2
jupyter_server==2.15.0
jupyter_server_terminals==0.5.3
jupyterlab==4.3.5
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.3
jupyterlab_widgets==3.0.13
kiwisolver==1.4.8
MarkupSafe==3.0.2
matplotlib==3.10.1
matplotlib-inline==0.1.7
memo-lang==1.1.0
mistune==3.1.2
ml_dtypes==0.5.1
nbclient==0.10.2
nbconvert==7.16.6
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.4.2
nodeenv==1.9.1
notebook_shim==0.2.4
numpy==2.2.3
opt_einsum==3.4.0
optype==0.9.1
overrides==7.7.0
packaging==24.2
pandas==2.2.3
pandas-stubs==2.2.3.241126
pandocfilters==1.5.1
parso==0.8.4
pexpect==4.9.0
pillow==11.1.0
platformdirs==4.3.6
plotly==5.24.1
pre_commit==4.1.0
prometheus_client==0.21.1
prompt_toolkit==3.0.50
psutil==7.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
pycparser==2.22
pydantic==2.10.6
pydantic_core==2.27.2
Pygments==2.19.1
pygraphviz==1.14
pyparsing==3.2.1
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-json-logger==3.3.0
pytz==2025.1
PyYAML==6.0.2
pyzmq==26.2.1
referencing==0.36.2
requests==2.32.3
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.23.1
ruff==0.9.10
scikit-learn==1.6.1
scipy==1.15.2
scipy-stubs==1.15.2.0
seaborn==0.13.2
Send2Trash==1.8.3
setuptools==75.8.2
six==1.17.0
sniffio==1.3.1
soupsieve==2.6
SQLAlchemy==2.0.38
stack-data==0.6.3
tabulate==0.9.0
tenacity==9.0.0
terminado==0.18.1
threadpoolctl==3.5.0
tinycss2==1.4.0
toml==0.10.2
tornado==6.4.2
tqdm==4.67.1
traitlets==5.14.3
types-python-dateutil==2.9.0.20241206
types-pytz==2025.1.0.20250204
typing_extensions==4.12.2
tzdata==2025.1
uri-template==1.3.0
urllib3==2.3.0
virtualenv==20.29.3
wcwidth==0.2.13
webcolors==24.11.1
webencodings==0.5.1
websocket-client==1.8.0
widgetsnbextension==4.0.13
xarray==2025.1.2
zipp==3.21.0
References

Nagel, Rosemarie (1995). “Unraveling in Guessing Games: An Experimental Study.” American Economic Review, 85(5), 1313-1326.
Footnotes
1. You might be wondering why n, which is bound to the reader’s mental model of the population’s collective mind, is not also bracketed. In fact it is; the dot notation is just a convenience. This return statement could also be written reader[ E[population[n]] ].