import jax
import jax.numpy as jnp
from memo import memo
from enum import IntEnum, auto
from matplotlib import pyplot as plt
Installation Test
Load the dependencies
Run a memo model
class Card(IntEnum):
    """Two-member domain used as the choice space of the `game` model.

    Members are ``ONE`` and ``TWO``. Because they are created with
    ``auto()`` on an ``IntEnum``, their integer values are 1 and 2.
    """

    ONE = auto()
    TWO = auto()
@memo
def game[_c: Card]():
in Card, wpp=1)
player: chooses(c return Pr[player.c == _c]
# Evaluate the model for every value of _c. print_table=True also prints the
# result as a table; the call itself returns the probabilities as an array.
game(print_table=True)
+----------+-------+
| _c: Card | game |
+----------+-------+
| ONE | 0.5 |
| TWO | 0.5 |
+----------+-------+
Array([0.5, 0.5], dtype=float32)
Check that pandas is installed
# Ask memo for a pandas view of the result, exposed as .aux.pandas.
# NOTE(review): presumably this raises if pandas is not installed, which is
# what makes this cell a useful installation check.
game(return_pandas=True).aux.pandas
|   | _c  |     |
|---|-----|-----|
| 0 | ONE | 0.5 |
| 1 | TWO | 0.5 |
Check that xarray is installed
# Ask memo for an xarray view of the result, exposed as .aux.xarray, with
# the _c axis labelled by Card member names.
# NOTE(review): presumably this raises if xarray is not installed.
game(return_xarray=True).aux.xarray
<xarray.DataArray '' (_c: 2)> Size: 8B
Array([0.5, 0.5], dtype=float32)
Coordinates:
  * _c       (_c) <U3 24B 'ONE' 'TWO'