import jax
import jax.numpy as jnp
from memo import memo
from memo import domain as product
from enum import IntEnum
from matplotlib import pyplot as plt
from jax.scipy.stats.beta import pdf as betapdf
"jax_platform_name", "cpu")
jax.config.update(
= jax.jit(betapdf)
betapdfjit
= jnp.linspace(0.0, 1.0, 10 + 1)
Proposal
class Response(IntEnum):
= 0
REJECT = 1
ACCEPT
= jnp.linspace(0, 1, 5)
Pref
@jax.jit
def val_money_proposer(proposal, response, pot):
return pot * (1 - proposal) * (response == Response.ACCEPT)
@jax.jit
def val_money_responder(proposal, response, pot):
return pot * proposal * (response == Response.ACCEPT)
@jax.jit
def val_di_proposer(proposal, response, pot):
return jnp.max(jnp.array([
val_money_responder(proposal, response, pot) - val_money_proposer(proposal, response, pot),
0
]))
@jax.jit
def val_ai_proposer(proposal, response, pot):
return jnp.max(jnp.array([
val_money_proposer(proposal, response, pot) - val_money_responder(proposal, response, pot),
0
]))
@jax.jit
def val_di_responder(proposal, response, pot):
return jnp.max(jnp.array([
val_money_proposer(proposal, response, pot) - val_money_responder(proposal, response, pot),
0
]))
@jax.jit
def val_ai_responder(proposal, response, pot):
return jnp.max(jnp.array([
val_money_responder(proposal, response, pot) - val_money_proposer(proposal, response, pot),
0
]))
@jax.jit
def preference_jointprior_pdfjit(pref_money, pref_dia, pref_aia):
return betapdf(pref_money, 1, 1) * betapdf(pref_dia, 1, 1) * betapdf(pref_aia, 1, 1)
Ultimatum Game
planning and inverse planning over an iterated game
The Ultimatum Game is a common paradigm in behavioral economics used to study strategic decision making and cooperation.
In the Ultimatum Game, two players—a proposer and a responder—decide how to allocate a pot of money. The proposer offers a partition (e.g. “you get 20%, I get 80%”), and the responder either accepts or rejects it. If the responder accepts, the money is divided according to the proposal; if rejected, both players get nothing.
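The payoff rule can be written down directly. Here is a minimal stand-alone sketch (the function name and signature are illustrative, not part of the models below):

```python
def payoffs(proposal, accepted, pot=10.0):
    """Return (proposer, responder) payoffs under the Ultimatum Game rules.

    `proposal` is the fraction of the pot offered to the responder.
    """
    if not accepted:
        return 0.0, 0.0  # rejection: both players get nothing
    return (1 - proposal) * pot, proposal * pot
```

For a 20/80 split that is accepted, the proposer keeps 8 and the responder gets 2; if it is rejected, both get 0.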
Actions in one-shot games
A simple model of responses
We’ll start by modeling a responder as a rational decision-maker whose choice is guided by three distinct preferences:
- Money (pref_money) - a preference for one’s own monetary payoff.
- Disadvantageous Inequity Aversion (DIA) (pref_dia) - a preference against getting less than the other player.
- Advantageous Inequity Aversion (AIA) (pref_aia) - a preference against getting more than the other player.
In this model, the responder’s utility is a linear combination of the objective outcome values, each weighted by one of the responder’s preferences:
U_{r}(a_r) = \omega_{money} \cdot V_{r,money} - \omega_{DIA} \cdot V_{r,DI} - \omega_{AIA} \cdot V_{r,AI}
where a_p is the action of the proposer (i.e. the proposal) and a_r is the action of the responder (i.e. the response). U_{r} is the subjective utility that a responder with preferences {(\omega_{money}, \omega_{DIA}, \omega_{AIA})} would derive from proposal a_p and response a_r. V_{r,\cdot} is the objective value of each outcome variable.
For example, if the proposer offers a 20/80 partition {(a_p = 0.2)} that the responder accepts (a_r=\text{ACCEPT}), then
\begin{align*} V_{r,money} =&~ 0.2 \cdot pot \\ V_{r,DI} =&~ (0.8 - 0.2) \cdot pot \\ V_{r,AI} =&~ 0 \end{align*}
For a given proposal and pot size, the responder chooses to accept or reject probabilistically, following a softmax decision rule:
P(a_r) \propto \exp(\beta_{r} \cdot U_{r}(a_r))
where \beta_{r} controls the responder’s sensitivity to subjective utility differences.
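To make this decision rule concrete, here is a minimal stand-alone sketch (separate from the memo models below) that computes the responder’s softmax acceptance probability directly; the weights w_money=1.0, w_dia=0.5, w_aia=0.0 are illustrative assumptions:

```python
import jax.numpy as jnp

def accept_prob(proposal, pot=10.0, w_money=1.0, w_dia=0.5, w_aia=0.0, beta=1.0):
    """Softmax probability that the responder ACCEPTs (illustrative weights)."""
    # outcome values if the responder accepts (all outcomes are zero on REJECT)
    v_money = proposal * pot                                        # responder's own payoff
    v_di = jnp.maximum((1 - proposal) * pot - proposal * pot, 0.0)  # disadvantageous inequity
    v_ai = jnp.maximum(proposal * pot - (1 - proposal) * pot, 0.0)  # advantageous inequity
    u_accept = w_money * v_money - w_dia * v_di - w_aia * v_ai
    u_reject = 0.0
    # softmax over the two possible responses
    return jnp.exp(beta * u_accept) / (jnp.exp(beta * u_accept) + jnp.exp(beta * u_reject))
```

Under these assumed weights, an even split is almost always accepted, while a 20/80 split against the responder is rejected most of the time.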
We’ll start with an oversimplified model of a responder’s choice about whether to accept or reject a proposal.
To keep things simple, we’ll model the responder as
having equal probability of any set of preferences, and
thinking that the proposer has equal probability of making any proposal.
@memo
def response_model_v0[_proposal: Proposal, _response: Response](beta_responder=1.0, pot=10.0):
    responder: given(pref_money in Pref, pref_dia in Pref, pref_aia in Pref, wpp=1)
    responder: thinks[
        proposer: chooses(proposal in Proposal, wpp=1)  ### oversimplification
    ]
    responder: observes [proposer.proposal] is _proposal
    responder: chooses(response in Response, wpp=exp(beta_responder * (
        pref_money * val_money_responder(proposer.proposal, response, pot)
        - pref_dia * val_di_responder(proposer.proposal, response, pot)
        - pref_aia * val_ai_responder(proposer.proposal, response, pot)
    )))
    return Pr[ responder.response == _response ]
response_model_v0()
Array([[0.8833704 , 0.11662987],
[0.8313682 , 0.16863157],
[0.74316436, 0.2568355 ],
[0.5780735 , 0.4219264 ],
[0.3393453 , 0.66065454],
[0.16564558, 0.83435374],
[0.25683576, 0.7431642 ],
[0.33988732, 0.66011244],
[0.40628162, 0.59371847],
[0.45854715, 0.541453 ],
[0.5 , 0.4999999 ]], dtype=float32)
A simple model of proposals
Now, consider the proposer’s perspective. The proposer seeks to predict the responder’s response to choose an optimal proposal. This involves nested inferences: the proposer is reasoning about the responder’s preferences while anticipating the responder’s reasoning about their proposals.
Formally, the proposer calculates the expected utility of a proposal by integrating over the responder’s responses:
E[U_{p}(a_p)] = \sum_{a_r \in \{ ACCEPT, REJECT \}} U_{p}(a_p, a_r) P(a_r \mid a_p)
Then, the proposer chooses a proposal (a_p) according to a softmax rule:
P(a_p) \propto \exp(\beta_{p} \cdot E[U_{p}(a_p)])
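As a sanity check on these two equations (separate from the memo model), the following sketch scores a grid of proposals for a purely money-motivated proposer, substituting a made-up acceptance curve for the responder model:

```python
import jax.numpy as jnp

proposal_grid = jnp.linspace(0.0, 1.0, 11)
# stand-in acceptance probabilities (a real responder model would supply these):
# more generous offers are assumed to be accepted more often
p_accept = jnp.clip(2.0 * proposal_grid, 0.0, 1.0)

def proposer_choice_probs(pot=10.0, beta=1.0):
    # expected utility of each proposal for a money-only proposer:
    # (1 - a_p) * pot if accepted, 0 if rejected
    eu = p_accept * (1.0 - proposal_grid) * pot
    # softmax over proposals
    w = jnp.exp(beta * eu)
    return w / w.sum()
```

Under this assumed acceptance curve the softmax choice distribution peaks at the 50/50 split: more generous offers waste money, stingier ones risk rejection.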
We can now use the response_model_v0 model in an oversimplified model of the proposer’s choice.
To keep things simple, we’ll model the proposer as having equal probability of any set of preferences.
@memo
def proposal_model_v0[_proposal: Proposal](beta_proposer=1.0, beta_responder=1.0, pot=10.0):
    proposer: given(pref_money in Pref, pref_dia in Pref, pref_aia in Pref, wpp=1)
    proposer: chooses(proposal in Proposal, wpp=exp(beta_proposer * imagine[
        responder: knows(proposal),
        responder: chooses(response in Response,
            wpp=response_model_v0[
                proposal,
                response
            ](beta_responder, pot)),
        E[ pref_money * val_money_proposer(proposal, responder.response, pot)
           - pref_dia * val_di_proposer(proposal, responder.response, pot)
           - pref_aia * val_ai_proposer(proposal, responder.response, pot)
        ]
    ]))
    return Pr[ proposer.proposal == _proposal ]
proposal_model_v0()
Array([0.05491067, 0.05632582, 0.06254137, 0.08616571, 0.17013934,
0.369276 , 0.09633861, 0.04321389, 0.02648459, 0.01922949,
0.01537458], dtype=float32)
With these oversimplified outlines of how the two players make choices in a one-shot game to maximize their expected utilities, let’s now model what the proposer infers about the responder’s preferences, given a sequence of rounds.
Preference Inference in Iterated Games
In an iterated Ultimatum Game, repeated interactions provide richer evidence about preferences. Observations from each round refine beliefs about preferences through Bayesian updating.
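The recursive update can be illustrated with a deliberately simplified one-dimensional version before the full three-preference memo model: a single money-preference weight on a discrete grid, with an assumed logistic acceptance likelihood (both are toy choices, not the model used below):

```python
import jax.numpy as jnp

pref_grid = jnp.linspace(0.0, 1.0, 5)  # candidate money-preference weights

def likelihood_accept(w_money, proposal, pot=10.0, beta=1.0):
    # toy responder: utility of accepting is w_money * proposal * pot, rejecting is 0
    u = beta * w_money * proposal * pot
    return jnp.exp(u) / (jnp.exp(u) + 1.0)

def update(belief, proposal, accepted):
    # Bayes' rule on the discretized preference grid
    lik = likelihood_accept(pref_grid, proposal)
    lik = lik if accepted else 1.0 - lik
    post = belief * lik
    return post / post.sum()

belief = jnp.ones_like(pref_grid) / pref_grid.size  # uniform prior
for prop, acc in [(0.3, True), (0.3, True)]:        # two accepted 30% offers
    belief = update(belief, prop, acc)
```

After observing two accepted 30% offers, the posterior mass shifts toward higher money preference, since high-w_money responders are the most likely to accept small offers.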
Define some imaginary data
proposals = jnp.array([Proposal[4], Proposal[3], Proposal[6], Proposal[8]])
responses = jnp.array([Response.ACCEPT, Response.REJECT, Response.ACCEPT, Response.REJECT])
@jax.jit
def get_proposal(trial):
return proposals[trial]
@jax.jit
def get_response(trial):
return responses[trial]
Inference of responder’s preferences
@memo
def responder_pref_model[
    _pref_money: Pref,
    _pref_dia: Pref,
    _pref_aia: Pref
](trial, beta_proposer=1.0, beta_responder=1.0, pot=10.0):
    """
    model of a rational observer's (e.g., the proposer's) inference about a responder's preferences
    """
    ### proposer's inference of responder's prefs based on responder's behavior in all previous trials
    proposer: thinks[
        responder: given(
            pref_money in Pref,
            pref_dia in Pref,
            pref_aia in Pref,
            wpp=responder_pref_model[
                pref_money,
                pref_dia,
                pref_aia,
            ](trial - 1,
              beta_proposer,
              beta_responder,
              pot
            ) if trial > 0 else preference_jointprior_pdfjit(pref_money, pref_dia, pref_aia)
        )
    ]
    ### here we do not model the proposer's choice, but rather use the proposal actually made
    proposer: chooses(proposal in Proposal, wpp=proposal == get_proposal(trial))
    ### proposer's prediction about what response the responder will make, given the
    ### preferences that the responder is inferred to have
    proposer: thinks[
        responder: knows(proposal),
        responder: chooses(response in Response, wpp=exp(beta_responder * (
            pref_money * val_money_responder(proposal, response, pot)
            - pref_dia * val_di_responder(proposal, response, pot)
            - pref_aia * val_ai_responder(proposal, response, pot)
        ))),
    ]
    ### proposer updates beliefs about responder's preferences by observing what response
    ### the responder actually made in this trial
    proposer: observes_that[responder.response == get_response(trial)]
    ### return the posterior belief about the responder's preferences
    proposer: knows(_pref_money, _pref_dia, _pref_aia)
    return E[proposer[Pr[
        responder.pref_money == _pref_money,
        responder.pref_dia == _pref_dia,
        responder.pref_aia == _pref_aia
    ]]]
responder_pref_model(0).shape
(5, 5, 5)
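The returned (5, 5, 5) array is the joint posterior over the three preference weights on the Pref grid. To read off a single preference, marginalize out the other two axes and take the expectation on the grid; a minimal sketch using a toy uniform joint in place of a real posterior:

```python
import jax.numpy as jnp

pref_grid = jnp.linspace(0.0, 1.0, 5)
joint = jnp.ones((5, 5, 5)) / 125.0  # toy joint posterior (uniform)

# marginal over pref_money: sum out the pref_dia and pref_aia axes
marg_money = joint.sum(axis=(1, 2))
# posterior mean of pref_money on the grid
e_money = jnp.dot(pref_grid, marg_money)
```

For the uniform joint this gives 0.5, the midpoint of the grid; the xarray-based loops below do the same marginalization by dimension name.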
Let’s see what the proposer thinks the responder’s preferences are, given a sequence of rounds.
### some imaginary data
proposals = jnp.array([Proposal[3], Proposal[3], Proposal[3], Proposal[3]])
responses = jnp.array([Response.ACCEPT, Response.ACCEPT, Response.ACCEPT, Response.ACCEPT])

@jax.jit
def get_proposal(t):
    return proposals[t]

@jax.jit
def get_response(t):
    return responses[t]

proposer_beta = 1.0
responder_beta = 1.0
potsize = 10

print(f":: Proposer's inference of responder's prefs ::")
assert proposals.size == responses.size
for trial in range(responses.size):
    print(
        f"\nTrial {trial} -- ",
        f"Proposal: {get_proposal(trial):0.2f} &",
        f"Response: {Response(get_response(trial)).name}")
    resx = responder_pref_model(
        trial=trial,
        beta_proposer=proposer_beta,
        beta_responder=responder_beta,
        pot=potsize,
        return_aux=True,
        return_xarray=True
    ).aux.xarray
    for pref in resx.coords.dims:
        marginalize_over = [pref_ for pref_ in resx.coords.dims if pref_ != pref]
        marginal_dist = resx.sum(dim=marginalize_over)
        expectation = jnp.dot(Pref, marginal_dist.values)
        print(f"E[{pref.lstrip('_pref_')}] = {expectation:0.3f}")
:: Proposer's inference of responder's prefs ::
Trial 0 -- Proposal: 0.30 & Response: ACCEPT
E[money] = 0.642
E[dia] = 0.300
E[aia] = 0.500
Trial 1 -- Proposal: 0.30 & Response: ACCEPT
E[money] = 0.707
E[dia] = 0.215
E[aia] = 0.500
Trial 2 -- Proposal: 0.30 & Response: ACCEPT
E[money] = 0.749
E[dia] = 0.171
E[aia] = 0.500
Trial 3 -- Proposal: 0.30 & Response: ACCEPT
E[money] = 0.779
E[dia] = 0.144
E[aia] = 0.500
loop over trials
### some imaginary data
proposals = jnp.array([Proposal[-3], Proposal[-3], Proposal[-3], Proposal[-3]])
responses = jnp.array([Response.ACCEPT, Response.ACCEPT, Response.ACCEPT, Response.ACCEPT])

@jax.jit
def get_proposal(t):
    return proposals[t]

@jax.jit
def get_response(t):
    return responses[t]

proposer_beta = 1.0
responder_beta = 1.0
potsize = 10

print(f":: Proposer's inference of responder's prefs ::")
assert proposals.size == responses.size
for trial in range(responses.size):
    print(
        f"\nTrial {trial} -- ",
        f"Proposal: {get_proposal(trial):0.2f} &",
        f"Response: {Response(get_response(trial)).name}")
    resx = responder_pref_model(
        trial=trial,
        beta_proposer=proposer_beta,
        beta_responder=responder_beta,
        pot=potsize,
        return_aux=True,
        return_xarray=True
    ).aux.xarray
    for pref in resx.coords.dims:
        marginalize_over = [pref_ for pref_ in resx.coords.dims if pref_ != pref]
        marginal_dist = resx.sum(dim=marginalize_over)
        expectation = jnp.dot(Pref, marginal_dist.values)
        print(f"E[{pref.lstrip('_pref_')}] = {expectation:0.3f}")
:: Proposer's inference of responder's prefs ::
Trial 0 -- Proposal: 0.80 & Response: ACCEPT
E[money] = 0.678
E[dia] = 0.500
E[aia] = 0.380
Trial 1 -- Proposal: 0.80 & Response: ACCEPT
E[money] = 0.727
E[dia] = 0.500
E[aia] = 0.353
Trial 2 -- Proposal: 0.80 & Response: ACCEPT
E[money] = 0.752
E[dia] = 0.500
E[aia] = 0.338
Trial 3 -- Proposal: 0.80 & Response: ACCEPT
E[money] = 0.769
E[dia] = 0.500
E[aia] = 0.326
loop over trials
### some imaginary data
proposals = jnp.array([Proposal[4], Proposal[3], Proposal[6], Proposal[8]])
responses = jnp.array([Response.ACCEPT, Response.REJECT, Response.ACCEPT, Response.REJECT])

@jax.jit
def get_proposal(t):
    return proposals[t]

@jax.jit
def get_response(t):
    return responses[t]

proposer_beta = 1.0
responder_beta = 1.0
potsize = 10

print(f":: Proposer's inference of responder's prefs ::")
assert proposals.size == responses.size
for trial in range(responses.size):
    print(
        f"\nTrial {trial} -- ",
        f"Proposal: {get_proposal(trial):0.2f} &",
        f"Response: {Response(get_response(trial)).name}")
    resx = responder_pref_model(
        trial=trial,
        beta_proposer=proposer_beta,
        beta_responder=responder_beta,
        pot=potsize,
        return_aux=True,
        return_xarray=True
    ).aux.xarray
    for pref in resx.coords.dims:
        marginalize_over = [pref_ for pref_ in resx.coords.dims if pref_ != pref]
        marginal_dist = resx.sum(dim=marginalize_over)
        expectation = jnp.dot(Pref, marginal_dist.values)
        print(f"E[{pref.lstrip('_pref_')}] = {expectation:0.3f}")
:: Proposer's inference of responder's prefs ::
Trial 0 -- Proposal: 0.40 & Response: ACCEPT
E[money] = 0.626
E[dia] = 0.441
E[aia] = 0.500
Trial 1 -- Proposal: 0.30 & Response: REJECT
E[money] = 0.539
E[dia] = 0.608
E[aia] = 0.500
Trial 2 -- Proposal: 0.60 & Response: ACCEPT
E[money] = 0.631
E[dia] = 0.641
E[aia] = 0.467
Trial 3 -- Proposal: 0.80 & Response: REJECT
E[money] = 0.401
E[dia] = 0.564
E[aia] = 0.696
loop over trials
### some imaginary data
proposals = jnp.array([Proposal[3], Proposal[4], Proposal[8], Proposal[5], Proposal[6]])
responses = jnp.array([Response.REJECT, Response.REJECT, Response.ACCEPT, Response.REJECT, Response.ACCEPT])

@jax.jit
def get_proposal(t):
    return proposals[t]

@jax.jit
def get_response(t):
    return responses[t]

proposer_beta = 1.0
responder_beta = 1.0
potsize = 10

print(f":: Proposer's inference of responder's prefs ::")
assert proposals.size == responses.size
for trial in range(responses.size):
    print(
        f"\nTrial {trial} -- ",
        f"Proposal: {get_proposal(trial):0.2f} &",
        f"Response: {Response(get_response(trial)).name}")
    resx = responder_pref_model(
        trial=trial,
        beta_proposer=proposer_beta,
        beta_responder=responder_beta,
        pot=potsize,
        return_aux=True,
        return_xarray=True
    ).aux.xarray
    for pref in resx.coords.dims:
        marginalize_over = [pref_ for pref_ in resx.coords.dims if pref_ != pref]
        marginal_dist = resx.sum(dim=marginalize_over)
        expectation = jnp.dot(Pref, marginal_dist.values)
        print(f"E[{pref.lstrip('_pref_')}] = {expectation:0.3f}")
:: Proposer's inference of responder's prefs ::
Trial 0 -- Proposal: 0.30 & Response: REJECT
E[money] = 0.396
E[dia] = 0.646
E[aia] = 0.500
Trial 1 -- Proposal: 0.40 & Response: REJECT
E[money] = 0.224
E[dia] = 0.692
E[aia] = 0.500
Trial 2 -- Proposal: 0.80 & Response: ACCEPT
E[money] = 0.394
E[dia] = 0.737
E[aia] = 0.267
Trial 3 -- Proposal: 0.50 & Response: REJECT
E[money] = 0.163
E[dia] = 0.676
E[aia] = 0.170
Trial 4 -- Proposal: 0.60 & Response: ACCEPT
E[money] = 0.216
E[dia] = 0.692
E[aia] = 0.165
Inference of proposer’s preferences
Now let’s model the responder inferring the proposer’s preferences. This will involve using the responder_pref_model defined above.
@memo
def proposer_pref_model[
    _pref_money: Pref,
    _pref_dia: Pref,
    _pref_aia: Pref
](trial, beta_proposer=1.0, beta_responder=1.0, pot=10.0):
    """
    model of a rational observer's (e.g., the responder's) inference about a proposer's preferences
    """
    ### responder's estimates of the proposer's mental contents
    responder: thinks[
        ### (responder's inference of) proposer's prefs based on proposer's behavior in all previous trials
        proposer: given(
            pref_money in Pref,
            pref_dia in Pref,
            pref_aia in Pref,
            wpp=proposer_pref_model[
                pref_money,
                pref_dia,
                pref_aia,
            ](trial - 1,
              beta_proposer,
              beta_responder,
              pot
            ) if trial > 0 else preference_jointprior_pdfjit(pref_money, pref_dia, pref_aia)
        ),
        ### (responder's estimate of) proposer's inference of the responder's preferences,
        ### based on all previous trials (calls responder_pref_model() defined above)
        proposer: thinks[
            responder: given(
                pref_money in Pref,
                pref_dia in Pref,
                pref_aia in Pref,
                wpp=responder_pref_model[
                    pref_money,
                    pref_dia,
                    pref_aia,
                ](trial - 1,
                  beta_proposer,
                  beta_responder,
                  pot
                ) if trial > 0 else preference_jointprior_pdfjit(pref_money, pref_dia, pref_aia)
            ),
        ],
        ### (responder's prediction of) proposer's choice of which split to propose:
        ### softmax choice based on expected utility of a proposal, marginalizing over
        ### (the responder's estimate of the proposer's estimates about) the responder's decision.
        proposer: chooses(proposal in Proposal, wpp=exp(beta_proposer * imagine[
            responder: knows(proposal),
            ### (responder's reasoning about) proposer's guess about the probability
            ### that the responder will choose `response` if the proposer chooses
            ### `proposal`.
            responder: chooses(response in Response, wpp=exp(beta_responder * (
                pref_money * val_money_responder(proposal, response, pot)
                - pref_dia * val_di_responder(proposal, response, pot)
                - pref_aia * val_ai_responder(proposal, response, pot)
            ))),
            ### expected utility to the proposer of the (`proposal`, `response`)
            ### pair under consideration
            E[ pref_money * val_money_proposer(proposal, responder.response, pot)
               - pref_dia * val_di_proposer(proposal, responder.response, pot)
               - pref_aia * val_ai_proposer(proposal, responder.response, pot)
            ]
        ]))
    ]
    ### responder updates beliefs about proposer's preferences by observing what proposal
    ### the proposer actually made in this trial
    responder: observes_that[proposer.proposal == get_proposal(trial)]
    ### return the posterior belief about the proposer's preferences
    responder: knows(_pref_money, _pref_dia, _pref_aia)
    return responder[Pr[
        proposer.pref_money == _pref_money,
        proposer.pref_dia == _pref_dia,
        proposer.pref_aia == _pref_aia
    ]]
proposer_pref_model(3).shape
(5, 5, 5)
Let’s see what the responder thinks the proposer’s preferences are, given a sequence of rounds.
### some imaginary data
proposals = jnp.array([Proposal[3], Proposal[3], Proposal[3], Proposal[3]])
responses = jnp.array([Response.ACCEPT, Response.ACCEPT, Response.ACCEPT, Response.ACCEPT])

@jax.jit
def get_proposal(t):
    return proposals[t]

@jax.jit
def get_response(t):
    return responses[t]

proposer_beta = 1.0
responder_beta = 1.0
potsize = 10

print(f":: Responder's inference of proposer's prefs ::")
assert proposals.size == responses.size
for trial in range(proposals.size):
    print(
        f"\nTrial {trial} -- ",
        f"Proposal: {get_proposal(trial):0.2f} &",
        f"Response: {Response(get_response(trial)).name}")
    resx = proposer_pref_model(
        trial=trial,
        beta_proposer=proposer_beta,
        beta_responder=responder_beta,
        pot=potsize,
        return_aux=True,
        return_xarray=True
    ).aux.xarray
    for pref in resx.coords.dims:
        marginalize_over = [pref_ for pref_ in resx.coords.dims if pref_ != pref]
        marginal_dist = resx.sum(dim=marginalize_over)
        expectation = jnp.dot(Pref, marginal_dist.values)
        print(f"E[{pref.lstrip('_pref_')}] = {expectation:0.3f}")
:: Responder's inference of proposer's prefs ::
Trial 0 -- Proposal: 0.30 & Response: ACCEPT
E[money] = 0.472
E[dia] = 0.551
E[aia] = 0.374
Trial 1 -- Proposal: 0.30 & Response: ACCEPT
E[money] = 0.524
E[dia] = 0.580
E[aia] = 0.237
Trial 2 -- Proposal: 0.30 & Response: ACCEPT
E[money] = 0.599
E[dia] = 0.591
E[aia] = 0.148
Trial 3 -- Proposal: 0.30 & Response: ACCEPT
E[money] = 0.671
E[dia] = 0.595
E[aia] = 0.099
loop over trials
### some imaginary data
proposals = jnp.array([Proposal[-3], Proposal[-3], Proposal[-3], Proposal[-3]])
responses = jnp.array([Response.ACCEPT, Response.ACCEPT, Response.ACCEPT, Response.ACCEPT])

@jax.jit
def get_proposal(t):
    return proposals[t]

@jax.jit
def get_response(t):
    return responses[t]

proposer_beta = 1.0
responder_beta = 1.0
potsize = 10

print(f":: Responder's inference of proposer's prefs ::")
assert proposals.size == responses.size
for trial in range(proposals.size):
    print(
        f"\nTrial {trial} -- ",
        f"Proposal: {get_proposal(trial):0.2f} &",
        f"Response: {Response(get_response(trial)).name}")
    resx = proposer_pref_model(
        trial=trial,
        beta_proposer=proposer_beta,
        beta_responder=responder_beta,
        pot=potsize,
        return_aux=True,
        return_xarray=True
    ).aux.xarray
    for pref in resx.coords.dims:
        marginalize_over = [pref_ for pref_ in resx.coords.dims if pref_ != pref]
        marginal_dist = resx.sum(dim=marginalize_over)
        expectation = jnp.dot(Pref, marginal_dist.values)
        print(f"E[{pref.lstrip('_pref_')}] = {expectation:0.3f}")
:: Responder's inference of proposer's prefs ::
Trial 0 -- Proposal: 0.80 & Response: ACCEPT
E[money] = 0.301
E[dia] = 0.208
E[aia] = 0.566
Trial 1 -- Proposal: 0.80 & Response: ACCEPT
E[money] = 0.195
E[dia] = 0.070
E[aia] = 0.617
Trial 2 -- Proposal: 0.80 & Response: ACCEPT
E[money] = 0.137
E[dia] = 0.030
E[aia] = 0.660
Trial 3 -- Proposal: 0.80 & Response: ACCEPT
E[money] = 0.102
E[dia] = 0.014
E[aia] = 0.697
loop over trials
### some imaginary data
proposals = jnp.array([Proposal[4], Proposal[3], Proposal[6], Proposal[8]])
responses = jnp.array([Response.ACCEPT, Response.REJECT, Response.ACCEPT, Response.REJECT])

@jax.jit
def get_proposal(t):
    return proposals[t]

@jax.jit
def get_response(t):
    return responses[t]

proposer_beta = 1.0
responder_beta = 1.0
potsize = 10

print(f":: Responder's inference of proposer's prefs ::")
assert proposals.size == responses.size
for trial in range(proposals.size):
    print(
        f"\nTrial {trial} -- ",
        f"Proposal: {get_proposal(trial):0.2f} &",
        f"Response: {Response(get_response(trial)).name}")
    resx = proposer_pref_model(
        trial=trial,
        beta_proposer=proposer_beta,
        beta_responder=responder_beta,
        pot=potsize,
        return_aux=True,
        return_xarray=True
    ).aux.xarray
    for pref in resx.coords.dims:
        marginalize_over = [pref_ for pref_ in resx.coords.dims if pref_ != pref]
        marginal_dist = resx.sum(dim=marginalize_over)
        expectation = jnp.dot(Pref, marginal_dist.values)
        print(f"E[{pref.lstrip('_pref_')}] = {expectation:0.3f}")
:: Responder's inference of proposer's prefs ::
Trial 0 -- Proposal: 0.40 & Response: ACCEPT
E[money] = 0.589
E[dia] = 0.546
E[aia] = 0.413
Trial 1 -- Proposal: 0.30 & Response: REJECT
E[money] = 0.581
E[dia] = 0.579
E[aia] = 0.280
Trial 2 -- Proposal: 0.60 & Response: ACCEPT
E[money] = 0.584
E[dia] = 0.445
E[aia] = 0.320
Trial 3 -- Proposal: 0.80 & Response: REJECT
E[money] = 0.458
E[dia] = 0.168
E[aia] = 0.343
loop over trials
### some imaginary data
proposals = jnp.array([Proposal[3], Proposal[4], Proposal[8], Proposal[5], Proposal[6]])
responses = jnp.array([Response.REJECT, Response.REJECT, Response.ACCEPT, Response.REJECT, Response.ACCEPT])

@jax.jit
def get_proposal(t):
    return proposals[t]

@jax.jit
def get_response(t):
    return responses[t]

proposer_beta = 1.0
responder_beta = 1.0
potsize = 10

print(f":: Responder's inference of proposer's prefs ::")
assert proposals.size == responses.size
for trial in range(proposals.size):
    print(
        f"\nTrial {trial} -- ",
        f"Proposal: {get_proposal(trial):0.2f} &",
        f"Response: {Response(get_response(trial)).name}")
    resx = proposer_pref_model(
        trial=trial,
        beta_proposer=proposer_beta,
        beta_responder=responder_beta,
        pot=potsize,
        return_aux=True,
        return_xarray=True
    ).aux.xarray
    for pref in resx.coords.dims:
        marginalize_over = [pref_ for pref_ in resx.coords.dims if pref_ != pref]
        marginal_dist = resx.sum(dim=marginalize_over)
        expectation = jnp.dot(Pref, marginal_dist.values)
        print(f"E[{pref.lstrip('_pref_')}] = {expectation:0.3f}")
:: Responder's inference of proposer's prefs ::
Trial 0 -- Proposal: 0.30 & Response: REJECT
E[money] = 0.472
E[dia] = 0.551
E[aia] = 0.374
Trial 1 -- Proposal: 0.40 & Response: REJECT
E[money] = 0.542
E[dia] = 0.586
E[aia] = 0.299
Trial 2 -- Proposal: 0.80 & Response: ACCEPT
E[money] = 0.424
E[dia] = 0.404
E[aia] = 0.321
Trial 3 -- Proposal: 0.50 & Response: REJECT
E[money] = 0.550
E[dia] = 0.439
E[aia] = 0.345
Trial 4 -- Proposal: 0.60 & Response: ACCEPT
E[money] = 0.600
E[dia] = 0.357
E[aia] = 0.359
%reset -f
import sys
import platform
import importlib.metadata
print("Python:", sys.version)
print("Platform:", platform.system(), platform.release())
print("Processor:", platform.processor())
print("Machine:", platform.machine())
print("\nPackages:")
for name, version in sorted(
    ((dist.metadata["Name"], dist.version) for dist in importlib.metadata.distributions()),
    key=lambda x: x[0].lower()  # Sort case-insensitively
):
    print(f"{name}=={version}")
Python: 3.13.2 (main, Feb 5 2025, 18:58:04) [Clang 19.1.6 ]
Platform: Darwin 23.6.0
Processor: arm
Machine: arm64
Packages:
annotated-types==0.7.0
anyio==4.9.0
appnope==0.1.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
astroid==3.3.9
asttokens==3.0.0
async-lru==2.0.5
attrs==25.3.0
babel==2.17.0
beautifulsoup4==4.13.3
bleach==6.2.0
certifi==2025.1.31
cffi==1.17.1
cfgv==3.4.0
charset-normalizer==3.4.1
click==8.1.8
comm==0.2.2
contourpy==1.3.1
cycler==0.12.1
debugpy==1.8.13
decorator==5.2.1
defusedxml==0.7.1
dill==0.3.9
distlib==0.3.9
executing==2.2.0
fastjsonschema==2.21.1
filelock==3.18.0
fonttools==4.56.0
fqdn==1.5.1
h11==0.14.0
httpcore==1.0.7
httpx==0.28.1
identify==2.6.9
idna==3.10
importlib_metadata==8.6.1
ipykernel==6.29.5
ipython==9.0.2
ipython_pygments_lexers==1.1.1
ipywidgets==8.1.5
isoduration==20.11.0
isort==6.0.1
jax==0.5.3
jaxlib==0.5.3
jedi==0.19.2
Jinja2==3.1.6
joblib==1.4.2
json5==0.10.0
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
jupyter-cache==1.0.1
jupyter-events==0.12.0
jupyter-lsp==2.2.5
jupyter_client==8.6.3
jupyter_core==5.7.2
jupyter_server==2.15.0
jupyter_server_terminals==0.5.3
jupyterlab==4.3.6
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.3
jupyterlab_widgets==3.0.13
kiwisolver==1.4.8
MarkupSafe==3.0.2
matplotlib==3.10.1
matplotlib-inline==0.1.7
mccabe==0.7.0
memo-lang==1.1.2
mistune==3.1.3
ml_dtypes==0.5.1
nbclient==0.10.2
nbconvert==7.16.6
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.4.2
nodeenv==1.9.1
notebook_shim==0.2.4
numpy==2.2.4
opt_einsum==3.4.0
optype==0.9.2
overrides==7.7.0
packaging==24.2
pandas==2.2.3
pandas-stubs==2.2.3.250308
pandocfilters==1.5.1
parso==0.8.4
pexpect==4.9.0
pillow==11.1.0
platformdirs==4.3.7
plotly==5.24.1
pre_commit==4.2.0
prometheus_client==0.21.1
prompt_toolkit==3.0.50
psutil==7.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
pycparser==2.22
pydantic==2.10.6
pydantic_core==2.27.2
Pygments==2.19.1
pygraphviz==1.14
pylint==3.3.6
pyparsing==3.2.3
python-dateutil==2.9.0.post0
python-dotenv==1.1.0
python-json-logger==3.3.0
pytz==2025.2
PyYAML==6.0.2
pyzmq==26.3.0
referencing==0.36.2
requests==2.32.3
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.23.1
ruff==0.11.2
scikit-learn==1.6.1
scipy==1.15.2
scipy-stubs==1.15.2.1
seaborn==0.13.2
Send2Trash==1.8.3
setuptools==78.1.0
six==1.17.0
sniffio==1.3.1
soupsieve==2.6
SQLAlchemy==2.0.39
stack-data==0.6.3
tabulate==0.9.0
tenacity==9.0.0
terminado==0.18.1
threadpoolctl==3.6.0
tinycss2==1.4.0
toml==0.10.2
tomlkit==0.13.2
tornado==6.4.2
tqdm==4.67.1
traitlets==5.14.3
types-python-dateutil==2.9.0.20241206
types-pytz==2025.1.0.20250318
typing_extensions==4.12.2
tzdata==2025.2
uri-template==1.3.0
urllib3==2.3.0
virtualenv==20.29.3
wcwidth==0.2.13
webcolors==24.11.1
webencodings==0.5.1
websocket-client==1.8.0
widgetsnbextension==4.0.13
xarray==2025.3.0
zipp==3.21.0