Ultimatum Game

planning and inverse planning over an iterated game

One-Shot Ultimatum Game model involving a proposer with 3 preferences who represents a responder with 3 preferences

import jax
import jax.numpy as jnp
from memo import memo
from memo import domain as product
from enum import IntEnum
from matplotlib import pyplot as plt
from jax.scipy.stats.beta import pdf as betapdf

# Small discrete models only -- force everything onto the CPU backend.
jax.config.update("jax_platform_name", "cpu")


# Pre-jitted Beta pdf, used as a prior density over preference weights.
betapdfjit = jax.jit(betapdf)

# Offers are fractions of the pot on a 0.1 grid: 0.0, 0.1, ..., 1.0.
Proposal = jnp.linspace(0.0, 1.0, 11)

# The responder's binary action on a proposal.
class Response(IntEnum):
    REJECT = 0
    ACCEPT = 1

# Grid of possible preference-weight values in [0, 1] (5 evenly spaced points).
Pref = jnp.linspace(0, 1, 5)

@jax.jit
def val_money_proposer(proposal, response, pot):
    # Proposer keeps the complement of the offered share, but only if accepted.
    accepted = response == Response.ACCEPT
    kept_share = 1 - proposal
    return pot * kept_share * accepted

@jax.jit
def val_money_responder(proposal, response, pot):
    # Responder receives the offered share of the pot, but only if accepted.
    accepted = response == Response.ACCEPT
    return pot * proposal * accepted

@jax.jit
def val_di_proposer(proposal, response, pot):
    # Disadvantageous inequity from the proposer's point of view: how much
    # the responder's monetary payoff exceeds the proposer's, floored at 0
    # (also 0 whenever the proposal is rejected, since both payoffs are 0).
    # jnp.maximum broadcasts elementwise, unlike jnp.max over a stacked
    # 2-array, so this also works for non-scalar proposals/responses.
    return jnp.maximum(
        val_money_responder(proposal, response, pot)
        - val_money_proposer(proposal, response, pot),
        0.0,
    )

@jax.jit
def val_ai_proposer(proposal, response, pot):
    # Advantageous inequity from the proposer's point of view: how much the
    # proposer's monetary payoff exceeds the responder's, floored at 0.
    # jnp.maximum broadcasts elementwise, unlike jnp.max over a stacked
    # 2-array, so this also works for non-scalar proposals/responses.
    return jnp.maximum(
        val_money_proposer(proposal, response, pot)
        - val_money_responder(proposal, response, pot),
        0.0,
    )

@jax.jit
def val_di_responder(proposal, response, pot):
    # Disadvantageous inequity from the responder's point of view: how much
    # the proposer's monetary payoff exceeds the responder's, floored at 0.
    # jnp.maximum broadcasts elementwise, unlike jnp.max over a stacked
    # 2-array, so this also works for non-scalar proposals/responses.
    return jnp.maximum(
        val_money_proposer(proposal, response, pot)
        - val_money_responder(proposal, response, pot),
        0.0,
    )

@jax.jit
def val_ai_responder(proposal, response, pot):
    # Advantageous inequity from the responder's point of view: how much the
    # responder's monetary payoff exceeds the proposer's, floored at 0.
    # jnp.maximum broadcasts elementwise, unlike jnp.max over a stacked
    # 2-array, so this also works for non-scalar proposals/responses.
    return jnp.maximum(
        val_money_responder(proposal, response, pot)
        - val_money_proposer(proposal, response, pot),
        0.0,
    )

# Responder policy P(_response | _proposal), marginalizing over the
# responder's three preference weights (uniform over the Pref grid).
@memo
def response_model_v0[_proposal: Proposal, _response: Response](beta=1.0, pot=100.0):
    # Responder's private preference weights: taste for money, plus aversions
    # to disadvantageous (dia) and advantageous (aia) inequity; uniform prior.
    responder: given(pref_money in Pref, pref_dia in Pref, pref_aia in Pref, wpp=1)
    # The responder models the proposer as choosing uniformly at random.
    responder: thinks[
        proposer: chooses(proposal in Proposal, wpp=1)  ### oversimplification
    ]
    # Condition on the proposal that was actually made.
    responder: observes [proposer.proposal] is _proposal
    # Softmax choice (inverse temperature beta) over the responder's utility:
    # money gained minus the two weighted inequity penalties.
    responder: chooses(response in Response, wpp=exp(beta * ( 
            pref_money * val_money_responder(proposer.proposal, response, pot)
            - pref_dia * val_di_responder(proposer.proposal, response, pot)
            - pref_aia * val_ai_responder(proposer.proposal, response, pot)
    )))
    return Pr[ responder.response == _response ]


response_model_v0(beta=1.0, pot=10)

# Proposer policy P(_proposal), marginalizing over the proposer's own
# preference weights (uniform over the Pref grid).
@memo
def proposal_model_v0[_proposal: Proposal](beta=1.0, p2beta=1.0, pot=100.0):
    proposer: given(pref_money in Pref, pref_dia in Pref, pref_aia in Pref, wpp=1)
    # Softmax choice of proposal (inverse temperature beta), scoring each
    # proposal by the proposer's expected utility under an imagined responder
    # who reacts according to response_model_v0 (with temperature p2beta).
    proposer: chooses(proposal in Proposal, wpp=exp(beta * imagine[
        responder: knows(proposal),
        responder: chooses(
            response in Response,
            wpp=response_model_v0[
                proposal, 
                response
            ](p2beta, pot)),
        # Expected utility: money kept minus the weighted inequity penalties.
        E[ 
            pref_money * val_money_proposer(proposal, responder.response, pot)
            - pref_dia * val_di_proposer(proposal, responder.response, pot)
            - pref_aia * val_ai_proposer(proposal, responder.response, pot)
        ]
    ]))
    return Pr[ proposer.proposal == _proposal ]

proposal_model_v0(beta=1.0, p2beta=1.0, pot=10)
Array([[0.8833704 , 0.11662987],
       [0.8313682 , 0.16863157],
       [0.74316436, 0.2568355 ],
       [0.5780735 , 0.4219264 ],
       [0.3393453 , 0.66065454],
       [0.16564558, 0.83435374],
       [0.25683576, 0.7431642 ],
       [0.33988732, 0.66011244],
       [0.40628162, 0.59371847],
       [0.45854715, 0.541453  ],
       [0.5       , 0.4999999 ]], dtype=float32)
Array([0.05491067, 0.05632582, 0.06254137, 0.08616571, 0.17013934,
       0.369276  , 0.09633861, 0.04321389, 0.02648459, 0.01922949,
       0.01537458], dtype=float32)

Now let’s put priors on the proposer’s preferences, and what the proposer thinks the responder’s preferences are.

# Responder policy P(_response | _proposal) for *given* preference weights
# (_pref_money, _pref_dia, _pref_aia), so callers can supply or marginalize
# the preferences themselves.
@memo
def response_model[
    _proposal: Proposal, 
    _response: Response, 
    _pref_money: Pref, 
    _pref_dia: Pref, 
    _pref_aia: Pref
](
    beta=1.0, 
    pot=100.0
):
    # Earlier variant drew the preferences from Beta(1, 1) priors here;
    # they are now passed in as arguments instead.
    # responder: given(pref_money in Pref, wpp=betapdfjit(pref_money, 1, 1))
    # responder: given(pref_dia in Pref, wpp=betapdfjit(pref_dia, 1, 1))
    # responder: given(pref_aia in Pref, wpp=betapdfjit(pref_aia, 1, 1))
    responder: knows(_pref_money, _pref_dia, _pref_aia)
    # The responder models the proposer as choosing uniformly at random.
    responder: thinks[
        proposer: chooses(proposal in Proposal, wpp=1)
    ]
    responder: observes [proposer.proposal] is _proposal
    # Softmax choice over the responder's utility, as in response_model_v0.
    # NOTE(review): memo warns this E[...] is redundant (nothing remains to
    # marginalize once the proposal is observed) -- harmless but removable.
    responder: chooses(response in Response, wpp=exp(beta * E[ 
            _pref_money * val_money_responder(proposer.proposal, response, pot)
            - _pref_dia * val_di_responder(proposer.proposal, response, pot)
            - _pref_aia * val_ai_responder(proposer.proposal, response, pot)
        ]
    ))
    return Pr[ responder.response == _response ]


# response_model(1.0, 10)

# Proposer policy P(_proposal) with Beta(1, 1) (uniform) priors both on the
# proposer's own preferences and on what the proposer thinks the responder's
# preferences are.
@memo
def proposal_model[_proposal: Proposal](beta=1.0, p2beta=1.0, pot=100.0):
    proposer: given(pref_money in Pref, wpp=betapdfjit(pref_money, 1, 1))
    proposer: given(pref_dia in Pref, wpp=betapdfjit(pref_dia, 1, 1))
    proposer: given(pref_aia in Pref, wpp=betapdfjit(pref_aia, 1, 1))
    # The proposer is uncertain about the responder's preference weights.
    proposer: thinks[
        responder: given(pref_money in Pref, wpp=betapdfjit(pref_money, 1, 1)),
        responder: given(pref_dia in Pref, wpp=betapdfjit(pref_dia, 1, 1)),
        responder: given(pref_aia in Pref, wpp=betapdfjit(pref_aia, 1, 1)),
    ]
    # Softmax choice of proposal, scoring by expected utility under an
    # imagined responder whose reaction follows response_model with the
    # imagined preference weights (inverse temperature p2beta).
    proposer: chooses(proposal in Proposal, wpp=exp(beta * imagine[
        responder: knows(proposal),
        responder: chooses(response in Response, wpp=response_model[
            proposal, 
            response,
            pref_money,
            pref_dia,
            pref_aia,
        ](p2beta, pot)),
        E[ 
            pref_money * val_money_proposer(proposal, responder.response, pot)
            - pref_dia * val_di_proposer(proposal, responder.response, pot)
            - pref_aia * val_ai_proposer(proposal, responder.response, pot)
        ]
    ]))
    return Pr[ proposer.proposal == _proposal ]

# proposal_model(1.0, 1.0, 10)
Warning: Redundant expectation, not marginalizing...
|     responder: chooses(response in Response, wpp=exp(beta * E[ 
|                                                             ^

Now let’s model what the proposer infers the responder’s preferences are, given a sequence of rounds.

### some imaginary data
# Four rounds of the same 0.3 offer (Proposal[3]), all accepted.
proposals = jnp.array([Proposal[3], Proposal[3], Proposal[3], Proposal[3]])
responses = jnp.array([Response.ACCEPT, Response.ACCEPT, Response.ACCEPT, Response.ACCEPT])

# Trial-indexed lookups into the data arrays above.
@jax.jit
def get_proposal(t):
    return proposals[t]

@jax.jit
def get_response(t):
    return responses[t]

# The proposer's inference over the responder's preference weights after
# observing trials 0..t. Recursive: the posterior from trial t-1 serves as
# the prior for trial t; at t=0 the prior is a product of Beta(1, 1)
# (i.e. uniform) densities.
@memo
def responder_pref_model[
    _pref_money: Pref, 
    _pref_dia: Pref, 
    _pref_aia: Pref
](t, beta=1.0, beta_responder=1.0, pot=100.0):
    proposer: knows(_pref_money, _pref_dia, _pref_aia)
    # Prior over the responder's preferences: last round's posterior, or the
    # uniform Beta(1, 1) product on the first round.
    proposer: thinks[
        responder: given(
            pref_money in Pref, 
            pref_dia in Pref, 
            pref_aia in Pref, 
            wpp=responder_pref_model[
                pref_money, 
                pref_dia, 
                pref_aia,
            ](
                t - 1, 
                beta, 
                beta_responder, 
                pot
            ) if t > 0 else (
                betapdfjit(pref_money, 1, 1)
                * betapdfjit(pref_dia, 1, 1)
                * betapdfjit(pref_aia, 1, 1)
            )),
    ]
    # Pin the proposal to the one actually made on trial t.
    proposer: chooses(proposal in Proposal, wpp=proposal == get_proposal(t))
    # The responder reacts according to response_model under their
    # (unknown-to-the-proposer) preference weights.
    proposer: thinks[
        responder: knows(proposal),
        responder: chooses(response in Response, wpp=response_model[
            proposal, 
            response,
            pref_money,
            pref_dia,
            pref_aia,
        ](beta_responder, pot))
    ]
    # Bayesian update on the response actually observed on trial t.
    proposer: observes_that[responder.response == get_response(t)]
    proposer: thinks[
        responder: knows(_pref_money, _pref_dia, _pref_aia)
    ]
    # Read out the proposer's posterior over the responder's preferences.
    return E[proposer[Pr[ 
        responder.pref_money == _pref_money, 
        responder.pref_dia == _pref_dia, 
        responder.pref_aia == _pref_aia 
    ]]]


trial = 4
proposer_beta = 1.0
responder_beta = 1.0
potsize = 10
# Posterior over the responder's preferences, returned as an xarray so the
# marginals below are easy to compute.
# NOTE(review): t=4 indexes past the 4-trial data arrays; jnp clamps
# out-of-range indices to the last element -- confirm this is intended.
res = responder_pref_model(t=trial, beta=proposer_beta, beta_responder=responder_beta, pot=potsize, return_aux=True, return_xarray=True)
resx = res.aux.xarray

# Per-trial posterior means of each preference weight.
for trial in range(4):
    print(f"\nTrial {trial} -- Proposal: {get_proposal(trial):0.2f} & Response: {Response(get_response(trial)).name}")
    resx = responder_pref_model(t=trial, beta=proposer_beta, beta_responder=responder_beta, pot=potsize, return_aux=True, return_xarray=True).aux.xarray
    for pref in resx.coords.dims:
        # Sum out the other two preference dimensions, then take E[pref].
        marginalize_over = [pref_ for pref_ in resx.coords.dims if pref_ != pref]
        marginal_dist = resx.sum(dim=marginalize_over)
        expectation = jnp.dot(Pref, marginal_dist.values)
        # removeprefix strips the literal "_pref_" prefix; the previous
        # lstrip('_pref_') stripped any leading characters from the set
        # {_, p, r, e, f} and only produced the right labels by accident.
        print(f"E[{pref.removeprefix('_pref_')}|t={trial}] = {expectation:0.3f}")

Trial 0 -- Proposal: 0.30 & Response: ACCEPT
E[money|t=0] = 0.642
E[dia|t=0] = 0.300
E[aia|t=0] = 0.500

Trial 1 -- Proposal: 0.30 & Response: ACCEPT
E[money|t=1] = 0.707
E[dia|t=1] = 0.215
E[aia|t=1] = 0.500

Trial 2 -- Proposal: 0.30 & Response: ACCEPT
E[money|t=2] = 0.749
E[dia|t=2] = 0.171
E[aia|t=2] = 0.500

Trial 3 -- Proposal: 0.30 & Response: ACCEPT
E[money|t=3] = 0.779
E[dia|t=3] = 0.144
E[aia|t=3] = 0.500
### some imaginary data
# Four rounds of a generous 0.8 offer (Proposal[-3]), all accepted.
proposals = jnp.array([Proposal[-3], Proposal[-3], Proposal[-3], Proposal[-3]])
responses = jnp.array([Response.ACCEPT, Response.ACCEPT, Response.ACCEPT, Response.ACCEPT])

# Re-bind the lookup helpers so the model sees the new data
# (the printed output below confirms the 0.80 proposals are used).
@jax.jit
def get_proposal(t):
    return proposals[t]

@jax.jit
def get_response(t):
    return responses[t]

trial = 4
proposer_beta = 1.0
responder_beta = 1.0
potsize = 10
# Posterior over the responder's preferences for the all-accept 0.8 history.
# NOTE(review): t=4 indexes past the 4-trial data arrays; jnp clamps
# out-of-range indices to the last element -- confirm this is intended.
res = responder_pref_model(t=trial, beta=proposer_beta, beta_responder=responder_beta, pot=potsize, return_aux=True, return_xarray=True)
resx = res.aux.xarray

# Per-trial posterior means of each preference weight.
for trial in range(4):
    print(f"\nTrial {trial} -- Proposal: {get_proposal(trial):0.2f} & Response: {Response(get_response(trial)).name}")
    resx = responder_pref_model(t=trial, beta=proposer_beta, beta_responder=responder_beta, pot=potsize, return_aux=True, return_xarray=True).aux.xarray
    for pref in resx.coords.dims:
        # Sum out the other two preference dimensions, then take E[pref].
        marginalize_over = [pref_ for pref_ in resx.coords.dims if pref_ != pref]
        marginal_dist = resx.sum(dim=marginalize_over)
        expectation = jnp.dot(Pref, marginal_dist.values)
        # removeprefix strips the literal "_pref_" prefix; the previous
        # lstrip('_pref_') stripped any leading characters from the set
        # {_, p, r, e, f} and only produced the right labels by accident.
        print(f"E[{pref.removeprefix('_pref_')}|t={trial}] = {expectation:0.3f}")

Trial 0 -- Proposal: 0.80 & Response: ACCEPT
E[money|t=0] = 0.678
E[dia|t=0] = 0.500
E[aia|t=0] = 0.380

Trial 1 -- Proposal: 0.80 & Response: ACCEPT
E[money|t=1] = 0.727
E[dia|t=1] = 0.500
E[aia|t=1] = 0.353

Trial 2 -- Proposal: 0.80 & Response: ACCEPT
E[money|t=2] = 0.752
E[dia|t=2] = 0.500
E[aia|t=2] = 0.338

Trial 3 -- Proposal: 0.80 & Response: ACCEPT
E[money|t=3] = 0.769
E[dia|t=3] = 0.500
E[aia|t=3] = 0.326
### some imaginary data
# Mixed history: 0.4 accepted, 0.3 rejected, 0.6 accepted, 0.8 rejected.
proposals = jnp.array([Proposal[4], Proposal[3], Proposal[6], Proposal[8]])
responses = jnp.array([Response.ACCEPT, Response.REJECT, Response.ACCEPT, Response.REJECT])

# Re-bind the lookup helpers so the model sees the new data.
@jax.jit
def get_proposal(t):
    return proposals[t]

@jax.jit
def get_response(t):
    return responses[t]

trial = 4
proposer_beta = 1.0
responder_beta = 1.0
potsize = 10

# Per-trial posterior means of each preference weight for the mixed history.
for trial in range(4):
    print(f"\nTrial {trial} -- Proposal: {get_proposal(trial):0.2f} & Response: {Response(get_response(trial)).name}")
    resx = responder_pref_model(t=trial, beta=proposer_beta, beta_responder=responder_beta, pot=potsize, return_aux=True, return_xarray=True).aux.xarray
    for pref in resx.coords.dims:
        # Sum out the other two preference dimensions, then take E[pref].
        marginalize_over = [pref_ for pref_ in resx.coords.dims if pref_ != pref]
        marginal_dist = resx.sum(dim=marginalize_over)
        expectation = jnp.dot(Pref, marginal_dist.values)
        # removeprefix strips the literal "_pref_" prefix; the previous
        # lstrip('_pref_') stripped any leading characters from the set
        # {_, p, r, e, f} and only produced the right labels by accident.
        print(f"E[{pref.removeprefix('_pref_')}|t={trial}] = {expectation:0.3f}")

Trial 0 -- Proposal: 0.40 & Response: ACCEPT
E[money|t=0] = 0.626
E[dia|t=0] = 0.441
E[aia|t=0] = 0.500

Trial 1 -- Proposal: 0.30 & Response: REJECT
E[money|t=1] = 0.539
E[dia|t=1] = 0.608
E[aia|t=1] = 0.500

Trial 2 -- Proposal: 0.60 & Response: ACCEPT
E[money|t=2] = 0.631
E[dia|t=2] = 0.641
E[aia|t=2] = 0.467

Trial 3 -- Proposal: 0.80 & Response: REJECT
E[money|t=3] = 0.401
E[dia|t=3] = 0.564
E[aia|t=3] = 0.696

References


# IPython magic: wipe the notebook namespace before reporting the environment.
%reset -f
import sys
import platform
import importlib.metadata

# Runtime / hardware summary.
print("Python:", sys.version)
print("Platform:", platform.system(), platform.release())
print("Processor:", platform.processor())
print("Machine:", platform.machine())

# List every installed distribution as name==version.
print("\nPackages:")
for name, version in sorted(
    ((dist.metadata["Name"], dist.version) for dist in importlib.metadata.distributions()),
    key=lambda x: x[0].lower()  # Sort case-insensitively
):
    print(f"{name}=={version}")
Python: 3.13.2 (main, Feb  5 2025, 18:58:04) [Clang 19.1.6 ]
Platform: Darwin 23.6.0
Processor: arm
Machine: arm64

Packages:
annotated-types==0.7.0
anyio==4.9.0
appnope==0.1.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
astroid==3.3.9
asttokens==3.0.0
async-lru==2.0.5
attrs==25.3.0
babel==2.17.0
beautifulsoup4==4.13.3
bleach==6.2.0
certifi==2025.1.31
cffi==1.17.1
cfgv==3.4.0
charset-normalizer==3.4.1
click==8.1.8
comm==0.2.2
contourpy==1.3.1
cycler==0.12.1
debugpy==1.8.13
decorator==5.2.1
defusedxml==0.7.1
dill==0.3.9
distlib==0.3.9
executing==2.2.0
fastjsonschema==2.21.1
filelock==3.18.0
fonttools==4.56.0
fqdn==1.5.1
h11==0.14.0
httpcore==1.0.7
httpx==0.28.1
identify==2.6.9
idna==3.10
importlib_metadata==8.6.1
ipykernel==6.29.5
ipython==9.0.2
ipython_pygments_lexers==1.1.1
ipywidgets==8.1.5
isoduration==20.11.0
isort==6.0.1
jax==0.5.3
jaxlib==0.5.3
jedi==0.19.2
Jinja2==3.1.6
joblib==1.4.2
json5==0.10.0
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
jupyter-cache==1.0.1
jupyter-events==0.12.0
jupyter-lsp==2.2.5
jupyter_client==8.6.3
jupyter_core==5.7.2
jupyter_server==2.15.0
jupyter_server_terminals==0.5.3
jupyterlab==4.3.6
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.3
jupyterlab_widgets==3.0.13
kiwisolver==1.4.8
MarkupSafe==3.0.2
matplotlib==3.10.1
matplotlib-inline==0.1.7
mccabe==0.7.0
memo-lang==1.1.2
mistune==3.1.3
ml_dtypes==0.5.1
nbclient==0.10.2
nbconvert==7.16.6
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.4.2
nodeenv==1.9.1
notebook_shim==0.2.4
numpy==2.2.4
opt_einsum==3.4.0
optype==0.9.2
overrides==7.7.0
packaging==24.2
pandas==2.2.3
pandas-stubs==2.2.3.250308
pandocfilters==1.5.1
parso==0.8.4
pexpect==4.9.0
pillow==11.1.0
platformdirs==4.3.7
plotly==5.24.1
pre_commit==4.2.0
prometheus_client==0.21.1
prompt_toolkit==3.0.50
psutil==7.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
pycparser==2.22
pydantic==2.10.6
pydantic_core==2.27.2
Pygments==2.19.1
pygraphviz==1.14
pylint==3.3.6
pyparsing==3.2.3
python-dateutil==2.9.0.post0
python-dotenv==1.1.0
python-json-logger==3.3.0
pytz==2025.2
PyYAML==6.0.2
pyzmq==26.3.0
referencing==0.36.2
requests==2.32.3
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.23.1
ruff==0.11.2
scikit-learn==1.6.1
scipy==1.15.2
scipy-stubs==1.15.2.1
seaborn==0.13.2
Send2Trash==1.8.3
setuptools==78.1.0
six==1.17.0
sniffio==1.3.1
soupsieve==2.6
SQLAlchemy==2.0.39
stack-data==0.6.3
tabulate==0.9.0
tenacity==9.0.0
terminado==0.18.1
threadpoolctl==3.6.0
tinycss2==1.4.0
toml==0.10.2
tomlkit==0.13.2
tornado==6.4.2
tqdm==4.67.1
traitlets==5.14.3
types-python-dateutil==2.9.0.20241206
types-pytz==2025.1.0.20250318
typing_extensions==4.12.2
tzdata==2025.2
uri-template==1.3.0
urllib3==2.3.0
virtualenv==20.29.3
wcwidth==0.2.13
webcolors==24.11.1
webencodings==0.5.1
websocket-client==1.8.0
widgetsnbextension==4.0.13
xarray==2025.3.0
zipp==3.21.0