import jax
import jax.numpy as jnp
from memo import memo
from memo import domain as product
from enum import IntEnum
from matplotlib import pyplot as plt
from jax.scipy.stats.beta import pdf as betapdf
# Force JAX onto the CPU backend.
jax.config.update("jax_platform_name", "cpu")

# Compiled Beta-distribution density (not referenced by the code visible here).
betapdfjit = jax.jit(betapdf)

# Candidate offers: the fraction of the pot handed to the responder, on an
# evenly spaced grid 0.0, 0.1, ..., 1.0 (11 points).
Proposal = jnp.linspace(0.0, 1.0, num=11)
class Response(IntEnum):
    """Responder's reply to an offer.

    Integer-valued, so members can be compared directly against the
    integer-coded responses that flow through the models.
    """

    REJECT = 0  # offer refused: nobody is paid
    ACCEPT = 1  # offer taken: the pot is split as proposed
# Grid of preference weights {0, 0.25, 0.5, 0.75, 1}; used as the support of
# the pref_money / pref_dia / pref_aia weights in the memo models.
Pref = jnp.linspace(0, 1, num=5)
@jax.jit
def val_money_proposer(proposal, response, pot):
    """Money the proposer ends up with.

    Args:
        proposal: fraction of ``pot`` offered to the responder.
        response: a ``Response`` value; only ``Response.ACCEPT`` pays out.
        pot: total amount being split.

    Returns:
        ``pot * (1 - proposal)`` if the offer is accepted, otherwise zero.
    """
    kept_fraction = 1 - proposal
    accepted = response == Response.ACCEPT
    return pot * kept_fraction * accepted
@jax.jit
def val_money_responder(proposal, response, pot):
    """Money the responder ends up with.

    Args:
        proposal: fraction of ``pot`` offered to the responder.
        response: a ``Response`` value; only ``Response.ACCEPT`` pays out.
        pot: total amount being split.

    Returns:
        ``pot * proposal`` if the offer is accepted, otherwise zero.
    """
    offered_fraction = proposal
    accepted = response == Response.ACCEPT
    return pot * offered_fraction * accepted
@jax.jit
def val_di_proposer(proposal, response, pot):
    """Disadvantageous-inequality term for the proposer.

    Computes ``max(responder_money - proposer_money, 0)``: how far the
    responder out-earns the proposer, clipped at zero.

    ``jnp.maximum`` is used instead of ``jnp.max(jnp.array([x, 0]))``.
    Stacking an array with a scalar fails, so the old form only worked for
    scalar inputs. ``jnp.maximum`` gives the same scalar result and also
    broadcasts elementwise over array arguments.
    """
    gap = (val_money_responder(proposal, response, pot)
           - val_money_proposer(proposal, response, pot))
    return jnp.maximum(gap, 0.0)
@jax.jit
def val_ai_proposer(proposal, response, pot):
    """Advantageous-inequality term for the proposer.

    Computes ``max(proposer_money - responder_money, 0)``: how far the
    proposer out-earns the responder, clipped at zero.

    ``jnp.maximum`` replaces ``jnp.max(jnp.array([x, 0]))``. The result is
    identical for scalars, and the function now also works elementwise on
    array inputs, where stacking with a scalar would fail.
    """
    gap = (val_money_proposer(proposal, response, pot)
           - val_money_responder(proposal, response, pot))
    return jnp.maximum(gap, 0.0)
@jax.jit
def val_di_responder(proposal, response, pot):
    """Disadvantageous-inequality term for the responder.

    Computes ``max(proposer_money - responder_money, 0)``: how far the
    proposer out-earns the responder, clipped at zero. Numerically this is
    the same quantity as ``val_ai_proposer``, seen from the other side.

    ``jnp.maximum`` replaces ``jnp.max(jnp.array([x, 0]))``. The result is
    identical for scalars, and the function now also works elementwise on
    array inputs.
    """
    shortfall = (val_money_proposer(proposal, response, pot)
                 - val_money_responder(proposal, response, pot))
    return jnp.maximum(shortfall, 0.0)
@jax.jit
def val_ai_responder(proposal, response, pot):
    """Advantageous-inequality term for the responder.

    Computes ``max(responder_money - proposer_money, 0)``: how far the
    responder out-earns the proposer, clipped at zero. Numerically this is
    the same quantity as ``val_di_proposer``, seen from the other side.

    ``jnp.maximum`` replaces ``jnp.max(jnp.array([x, 0]))``. The result is
    identical for scalars, and the function now also works elementwise on
    array inputs.
    """
    surplus = (val_money_responder(proposal, response, pot)
               - val_money_proposer(proposal, response, pot))
    return jnp.maximum(surplus, 0.0)
# Level-0 responder model: probability that a responder who observes the
# offer `_proposal` chooses `_response`.
#
# The result is indexed [_proposal, _response] over Proposal x Response.
#   beta: inverse-temperature on the responder's utility; each response is
#         weighted by exp(beta * utility).
#   pot:  total amount being split, forwarded to the val_* helpers.
#
# The responder's preference weights range over Pref with uniform weight
# (wpp=1), and the returned probability is taken over those weights.
@memo
def response_model_v0[_proposal: Proposal, _response: Response](beta=1.0, pot=100.0):
    responder: given(pref_money in Pref, pref_dia in Pref, pref_aia in Pref, wpp=1)
    # The responder models the proposer as picking any offer with equal weight.
    responder: thinks[
        proposer: chooses(proposal in Proposal, wpp=1) ### oversimplification
    ]
    responder: observes [proposer.proposal] is _proposal
    # Utility: weighted own money, minus weighted disadvantageous and
    # advantageous inequality from the responder's point of view.
    responder: chooses(response in Response, wpp=exp(beta * (
        pref_money * val_money_responder(proposer.proposal, response, pot)
        - pref_dia * val_di_responder(proposer.proposal, response, pot)
        - pref_aia * val_ai_responder(proposer.proposal, response, pot)
    )))
    return Pr[ responder.response == _response ]
# Evaluate the responder model with beta=1.0 and pot=10 (overriding the
# default pot of 100.0). The result is indexed [proposal, response]. As a bare
# expression statement it is only displayed when run interactively.
response_model_v0(1.0, 10)
# Level-0 proposer model: probability that a proposer offers `_proposal`.
#
# The result is indexed [_proposal] over Proposal.
#   beta:   inverse-temperature on the proposer's expected utility.
#   p2beta: beta forwarded to response_model_v0 for the imagined responder.
#   pot:    total amount being split, forwarded to response_model_v0 and the
#           val_* helpers.
#
# The proposer's preference weights range over Pref with uniform weight
# (wpp=1), and the returned probability is taken over those weights.
@memo
def proposal_model_v0[_proposal: Proposal](beta=1.0, p2beta=1.0, pot=100.0):
    proposer: given(pref_money in Pref, pref_dia in Pref, pref_aia in Pref, wpp=1)
    # Each candidate offer is weighted by exp(beta * E[utility]). The
    # expectation is over an imagined responder who knows the offer and
    # replies with the probabilities given by response_model_v0.
    proposer: chooses(proposal in Proposal, wpp=exp(beta * imagine[
        responder: knows(proposal),
        responder: chooses(
            response in Response,
            wpp=response_model_v0[
                proposal,
                response
            ](p2beta, pot)),
        # Proposer utility: weighted own money, minus weighted disadvantageous
        # and advantageous inequality from the proposer's point of view.
        E[
            pref_money * val_money_proposer(proposal, responder.response, pot)
            - pref_dia * val_di_proposer(proposal, responder.response, pot)
            - pref_aia * val_ai_proposer(proposal, responder.response, pot)
        ]
    ]))
    return Pr[ proposer.proposal == _proposal ]
# Evaluate the proposer model with beta=1.0, p2beta=1.0 and pot=10. The result
# has one probability per Proposal value. As a bare expression statement it is
# only displayed when run interactively.
proposal_model_v0(1.0, 1.0, 10)
# Recorded output of response_model_v0(1.0, 10).
# Rows index Proposal (0.0 ... 1.0); columns are (REJECT, ACCEPT):
# Array([[0.8833704 , 0.11662987],
#        [0.8313682 , 0.16863157],
#        [0.74316436, 0.2568355 ],
#        [0.5780735 , 0.4219264 ],
#        [0.3393453 , 0.66065454],
#        [0.16564558, 0.83435374],
#        [0.25683576, 0.7431642 ],
#        [0.33988732, 0.66011244],
#        [0.40628162, 0.59371847],
#        [0.45854715, 0.541453 ],
#        [0.5 , 0.4999999 ]], dtype=float32)
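# Recorded output of proposal_model_v0(1.0, 1.0, 10): one probability per
# Proposal value (0.0 ... 1.0):
# Array([0.05491067, 0.05632582, 0.06254137, 0.08616571, 0.17013934,
#        0.369276 , 0.09633861, 0.04321389, 0.02648459, 0.01922949,
#        0.01537458], dtype=float32)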