import jax
import jax.numpy as jnp
from memo import memo
from enum import IntEnum, auto
from matplotlib import pyplot as plt
Installation Test
Load the dependencies
Run a memo model
class Card(IntEnum):
    """The set of cards the player chooses from in the ``game`` memo model.

    ``auto()`` numbers IntEnum members from 1, so ``Card.ONE == 1`` and
    ``Card.TWO == 2``. memo iterates over this enum both as the domain of
    the player's choice and as the ``_c`` output axis.
    """

    ONE = auto()
    TWO = auto()
# Installation smoke test: a single-agent memo model.
# `[_c: Card]` uses PEP 695 bracket syntax (Python 3.12+), which memo reads
# as an output axis: calling `game()` yields one value per Card member
# (the recorded output is a length-2 array indexed by _c).
@memo
def game[_c: Card]():
    # The player picks a card; wpp=1 gives every card the same weight,
    # which is why the recorded output shows 0.5 for each card.
    player: chooses(c in Card, wpp=1)
    # Probability that the player's chosen card equals the axis value _c.
    return Pr[player.c == _c]
game(print_table=True)
+----------+-------+
| _c: Card | game  |
+----------+-------+
| ONE      | 0.5   |
| TWO      | 0.5   |
+----------+-------+
Array([0.5, 0.5], dtype=float32)
Check if pandas is installed
game(return_pandas=True).aux.pandas
|   | _c  |     |
|---|-----|-----|
| 0 | ONE | 0.5 |
| 1 | TWO | 0.5 |
Check if xarray is installed
game(return_xarray=True).aux.xarray
<xarray.DataArray '' (_c: 2)> Size: 8B
Array([0.5, 0.5], dtype=float32)
Coordinates:
  * _c       (_c) <U3 24B 'ONE' 'TWO'