import jax
import jax.numpy as jnp
from memo import memo
from memo import domain as product
from enum import IntEnum
from matplotlib import pyplot as plt
from jax.scipy.stats.norm import pdf as normpdf
normpdfjit = jax.jit(normpdf)
# Discretized domains for the model's variables
PolicePerformance = jnp.linspace(0, 1, 10+1, endpoint=True)  # how well the police are doing
Causal = jnp.linspace(-1, 1, 40+1, endpoint=True)            # strength and sign of the performance-to-arrests link
Arrests = jnp.linspace(0, 1, 10+1, endpoint=True)            # observed arrest rate
@jax.jit
def arrests_pdf(arrests, performance, causal_link):
    # mean arrests track performance only to the extent of the causal link
    arrests_mu = causal_link * (performance - 0.5) + 0.5
    arrests_sigma = 0.1
    return normpdf(arrests, arrests_mu, arrests_sigma)
@jax.jit
def reported_cl_pdf(reported, real, bias=0.0):
    return normpdf(reported, real + real*bias, 2.0)
@memo
def viewerModel[
    _prior_expectation_cl: Causal,
](
    reported_cl_observed,
    nobs,
):
    viewer: knows(
        _prior_expectation_cl,
    )
    viewer: thinks[
        # the viewer's prior over the causal link is their belief after the
        # previous newscast (recursive call); uniform if no newscast has been seen
        police: given(causal_link in Causal, wpp=(
            viewerModel[causal_link](reported_cl_observed, nobs - 1) if nobs > 0 else 1)),
        police: chooses(performance in PolicePerformance, wpp=1),
        police: chooses(arrests in Arrests, wpp=arrests_pdf(arrests, performance, causal_link)),
        news: knows(police.causal_link),
        news: chooses(reported_cl in Causal, wpp=(
            reported_cl_pdf(reported_cl, police.causal_link)
        )),
    ]
    viewer: observes_event(wpp=normpdfjit(news.reported_cl, reported_cl_observed, 0.2))
    return viewer[
        Pr[police.causal_link == _prior_expectation_cl]
    ]
Belief polarization
What we learn from new information is mediated by the mental models we use to interpret the data.
In priors and explanation we saw how prior beliefs can lead to different explanations of the same observation — two people with different priors might explain the same event in radically different ways.
In that model, the two observers update their beliefs in accordance with Bayes’ rule. Because they had different prior beliefs, they inferred different explanations for the same data. But although their explanations differed, both observers updated their beliefs in the same direction, just to different degrees. As more relevant information accumulates, the influence of priors should diminish. This idea underpins the optimism of rational discourse: shared evidence should lead to shared understanding.
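To make that intuition concrete, here is a minimal sketch, separate from the memo model in this chapter and using made-up likelihood values: two observers share the same likelihood model for a binary hypothesis H but start from different priors, and both see the same growing stream of supporting evidence.
def posterior_h1(prior_h1, n_obs, p_obs_given_h1=0.7, p_obs_given_h0=0.3):
    # posterior P(H=1) after n_obs independent observations that each favor H=1;
    # the likelihood values are illustrative assumptions, not part of the police model
    like_h1 = p_obs_given_h1 ** n_obs
    like_h0 = p_obs_given_h0 ** n_obs
    return prior_h1 * like_h1 / (prior_h1 * like_h1 + (1 - prior_h1) * like_h0)

for n in [0, 2, 5, 10, 20]:
    a = posterior_h1(0.9, n)  # observer with an optimistic prior
    b = posterior_h1(0.1, n)  # observer with a skeptical prior
    print(f"n={n:2d}  P_A(H)={a:.3f}  P_B(H)={b:.3f}  gap={a - b:.3f}")
Both posteriors head toward the same place; the prior only sets how quickly each observer gets there.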
But this doesn’t always seem to happen. Sometimes people will have more polarized beliefs after observing the same information. How is this possible? Is there a “rational” explanation?
This polarization can arise because we do not merely learn patterns in data; we also learn how to interpret data. Each new observation can reshape the underlying causal model through which we interpret future information. In the model developed here, each viewer learns from a news source how strongly police performance is causally linked to arrests, and that learned link then determines what any observed arrest data say about performance. Humans excel at integrating prior knowledge with observations, but “rational” does not entail “correct”.
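As a toy illustration of that idea, again outside the memo model and with made-up source-reliability numbers: two viewers who have learned different models of the news source itself move in opposite directions on the identical stream of reports, even starting from the same prior.
def update_h1(prior_h1, reports, p_r1_given_h1, p_r1_given_h0):
    # sequentially update P(H=1) given binary reports, under a viewer-specific
    # model of how the source's reports relate to the truth
    p = prior_h1
    for r in reports:
        like_h1 = p_r1_given_h1 if r == 1 else 1 - p_r1_given_h1
        like_h0 = p_r1_given_h0 if r == 1 else 1 - p_r1_given_h0
        p = p * like_h1 / (p * like_h1 + (1 - p) * like_h0)
    return p

reports = [1] * 10  # both viewers see exactly the same reports
p_A = update_h1(0.5, reports, p_r1_given_h1=0.8, p_r1_given_h0=0.2)  # "the source is mostly truthful"
p_B = update_h1(0.5, reports, p_r1_given_h1=0.3, p_r1_given_h0=0.7)  # "the source is mostly contrarian"
print(f"P_A(H)={p_A:.3f}  P_B(H)={p_B:.3f}")  # the two posteriors end up far apart
The memo models in this chapter do the same thing with continuous variables: what a viewer has come to believe about the causal link, after repeated newscasts, changes what the same arrest data imply about police performance.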
# expected value of the causal link under the viewer's posterior after nobs_+1 newscasts
reported = Causal[4]  # a negative reported causal link
for nobs_ in range(10):
    res_viewer = viewerModel(reported, nobs_, print_table=False, return_aux=True, return_xarray=True)
    xardata = res_viewer.aux.xarray
    print(f"nobs: {nobs_+1} E = {jnp.dot(xardata['_prior_expectation_cl'].values, xardata.values)}")
nobs: 1 E = -0.06499627977609634
nobs: 2 E = -0.12871907651424408
nobs: 3 E = -0.1903264820575714
nobs: 4 E = -0.2491239309310913
nobs: 5 E = -0.3045755922794342
nobs: 6 E = -0.3563281297683716
nobs: 7 E = -0.4041990041732788
nobs: 8 E = -0.4481450915336609
nobs: 9 E = -0.48824718594551086
nobs: 10 E = -0.5246806740760803
reported = Causal[-4]
for nobs_ in range(10):
    res_viewer = viewerModel(reported, nobs_, print_table=False, return_aux=True, return_xarray=True)
    xardata = res_viewer.aux.xarray
    print(f"nobs: {nobs_+1} E = {jnp.dot(xardata['_prior_expectation_cl'].values, xardata.values)}")
nobs: 1 E = 0.06774599850177765
nobs: 2 E = 0.1340755820274353
nobs: 3 E = 0.19804811477661133
nobs: 4 E = 0.25888949632644653
nobs: 5 E = 0.3160253167152405
nobs: 6 E = 0.36908823251724243
nobs: 7 E = 0.41790270805358887
nobs: 8 E = 0.4624665677547455
nobs: 9 E = 0.5029036998748779
nobs: 10 E = 0.5394370555877686
@memo
def performanceInferred[
    _arrests_observed: Arrests,
](reported_cl_observed, nobs):
    viewer: knows(_arrests_observed)
    viewer: thinks[
        # prior over the causal link is the viewer's belief after nobs newscasts
        police: given(causal_link in Causal, wpp=(
            viewerModel[causal_link](reported_cl_observed, nobs) if nobs > 0 else 1
        )),
        police: chooses(performance in PolicePerformance, wpp=1),
        police: chooses(arrests in Arrests, wpp=arrests_pdf(arrests, performance, causal_link)),
        news: knows(police.causal_link),
        news: chooses(reported_cl in Causal, wpp=(
            reported_cl_pdf(reported_cl, police.causal_link)
        )),
    ]
    viewer: observes_event(wpp=normpdfjit(news.reported_cl, reported_cl_observed, 0.2))
    viewer: observes [police.arrests] is _arrests_observed
    return viewer[ E[police.performance] ]
reported = Causal[4]
nobs_ = 4
_ = performanceInferred(reported, nobs_, print_table=True)
+----------------------------+----------------------+
| _arrests_observed: Arrests | performanceInferred |
+----------------------------+----------------------+
| 0.0 | 0.8483712673187256 |
| 0.10000000149011612 | 0.7999858856201172 |
| 0.20000000298023224 | 0.7301038503646851 |
| 0.30000001192092896 | 0.645521342754364 |
| 0.4000000059604645 | 0.5646161437034607 |
| 0.5 | 0.4999997317790985 |
| 0.6000000238418579 | 0.43538177013397217 |
| 0.699999988079071 | 0.35447949171066284 |
| 0.800000011920929 | 0.2698955237865448 |
| 0.9000000357627869 | 0.2000139355659485 |
| 1.0 | 0.15162856876850128 |
+----------------------------+----------------------+
fox_reported = Causal[4].item()        # Fox reports a negative causal link
msnbc_reported = Causal[-4].item()     # MSNBC reports a positive causal link
arrests_observed = Arrests[-2].item()  # a high observed arrest rate, shown to both viewers

nobs_list = list(range(10))

foxviewer_policeperformance__arrests = list()
msnbcviewer_policeperformance__arrests = list()

for nobs_ in nobs_list:
    res_fox = performanceInferred(fox_reported, nobs_, print_table=False, return_aux=True, return_xarray=True)
    xardata_fox = res_fox.aux.xarray
    foxviewer_policeperformance__arrests.append(xardata_fox.loc[arrests_observed].item())

    res_msnbc_viewer = performanceInferred(msnbc_reported, nobs_, print_table=False, return_aux=True, return_xarray=True)
    xardata_msnbc = res_msnbc_viewer.aux.xarray
    msnbcviewer_policeperformance__arrests.append(xardata_msnbc.loc[arrests_observed].item())

fig, ax = plt.subplots()
ax.plot(nobs_list, foxviewer_policeperformance__arrests, label='Fox')
ax.plot(nobs_list, msnbcviewer_policeperformance__arrests, label='MSNBC')
_ = ax.set_xlabel("number of newscasts viewed")
_ = ax.set_ylabel("inferred police performance")
_ = ax.set_title(f"arrests: {arrests_observed:0.3f}")
ax.legend()
fox_reported = Causal[4].item()
msnbc_reported = Causal[-4].item()
arrests_observed = Arrests[2].item()

nobs_list = list(range(10))

foxviewer_policeperformance__arrests = list()
msnbcviewer_policeperformance__arrests = list()

for nobs_ in nobs_list:
    res_fox = performanceInferred(fox_reported, nobs_, print_table=False, return_aux=True, return_xarray=True)
    xardata_fox = res_fox.aux.xarray
    foxviewer_policeperformance__arrests.append(xardata_fox.loc[arrests_observed].item())

    res_msnbc_viewer = performanceInferred(msnbc_reported, nobs_, print_table=False, return_aux=True, return_xarray=True)
    xardata_msnbc = res_msnbc_viewer.aux.xarray
    msnbcviewer_policeperformance__arrests.append(xardata_msnbc.loc[arrests_observed].item())

fig, ax = plt.subplots()
ax.plot(nobs_list, foxviewer_policeperformance__arrests, label='Fox')
ax.plot(nobs_list, msnbcviewer_policeperformance__arrests, label='MSNBC')
_ = ax.set_xlabel("number of newscasts viewed")
_ = ax.set_ylabel("inferred police performance")
_ = ax.set_title(f"arrests: {arrests_observed:0.3f}")
ax.legend()
fig, ax = plt.subplots()
for reported in [Causal[4].item()]:
    for nobs_ in [0, 1, 2, 4, 6, 8]:
        res = performanceInferred(reported, nobs_, print_table=False, return_aux=True, return_xarray=True)
        xrd = res.aux.xarray
        ax.plot(xrd["_arrests_observed"], xrd, label=f"reported: {reported:0.2f}, nObs: {nobs_+1}")
_ = ax.set_xlabel("Arrests Observed")
_ = ax.set_ylabel("Inferred Performance")
fig.legend()
fig, ax = plt.subplots()
for reported in [Causal[-4].item()]:
    for nobs_ in [0, 1, 2, 4, 6, 8]:
        res = performanceInferred(reported, nobs_, print_table=False, return_aux=True, return_xarray=True)
        xrd = res.aux.xarray
        ax.plot(xrd["_arrests_observed"], xrd, label=f"reported: {reported:0.2f}, nObs: {nobs_+1}")
_ = ax.set_xlabel("Arrests Observed")
_ = ax.set_ylabel("Inferred Performance")
fig.legend()
fig, ax = plt.subplots()
for reported in jnp.linspace(-1, 1, 7, endpoint=True):
    for nobs_ in [3]:
        res = performanceInferred(reported, nobs_, print_table=False, return_aux=True, return_xarray=True)
        xrd = res.aux.xarray
        ax.plot(xrd["_arrests_observed"], xrd, label=f"reported: {reported:0.2f}, nObs: {nobs_+1}")
_ = ax.set_xlabel("Arrests Observed")
_ = ax.set_ylabel("Inferred Performance")
fig.legend()
%reset -f
import sys
import platform
import importlib.metadata
print("Python:", sys.version)
print("Platform:", platform.system(), platform.release())
print("Processor:", platform.processor())
print("Machine:", platform.machine())
print("\nPackages:")
for name, version in sorted(
    ((dist.metadata["Name"], dist.version) for dist in importlib.metadata.distributions()),
    key=lambda x: x[0].lower()  # sort case-insensitively
):
    print(f"{name}=={version}")
Python: 3.13.2 (main, Feb 5 2025, 18:58:04) [Clang 19.1.6 ]
Platform: Darwin 23.6.0
Processor: arm
Machine: arm64
Packages:
annotated-types==0.7.0
anyio==4.8.0
appnope==0.1.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==3.0.0
async-lru==2.0.4
attrs==25.1.0
babel==2.17.0
beautifulsoup4==4.13.3
bleach==6.2.0
certifi==2025.1.31
cffi==1.17.1
cfgv==3.4.0
charset-normalizer==3.4.1
click==8.1.8
comm==0.2.2
contourpy==1.3.1
cycler==0.12.1
debugpy==1.8.13
decorator==5.2.1
defusedxml==0.7.1
distlib==0.3.9
executing==2.2.0
fastjsonschema==2.21.1
filelock==3.17.0
fonttools==4.56.0
fqdn==1.5.1
h11==0.14.0
httpcore==1.0.7
httpx==0.28.1
identify==2.6.8
idna==3.10
importlib_metadata==8.6.1
ipykernel==6.29.5
ipython==9.0.1
ipython_pygments_lexers==1.1.1
ipywidgets==8.1.5
isoduration==20.11.0
jax==0.5.2
jaxlib==0.5.1
jedi==0.19.2
Jinja2==3.1.6
joblib==1.4.2
json5==0.10.0
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
jupyter-cache==1.0.1
jupyter-events==0.12.0
jupyter-lsp==2.2.5
jupyter_client==8.6.3
jupyter_core==5.7.2
jupyter_server==2.15.0
jupyter_server_terminals==0.5.3
jupyterlab==4.3.5
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.3
jupyterlab_widgets==3.0.13
kiwisolver==1.4.8
MarkupSafe==3.0.2
matplotlib==3.10.1
matplotlib-inline==0.1.7
memo-lang==1.1.0
mistune==3.1.2
ml_dtypes==0.5.1
nbclient==0.10.2
nbconvert==7.16.6
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.4.2
nodeenv==1.9.1
notebook_shim==0.2.4
numpy==2.2.3
opt_einsum==3.4.0
optype==0.9.1
overrides==7.7.0
packaging==24.2
pandas==2.2.3
pandas-stubs==2.2.3.241126
pandocfilters==1.5.1
parso==0.8.4
pexpect==4.9.0
pillow==11.1.0
platformdirs==4.3.6
plotly==5.24.1
pre_commit==4.1.0
prometheus_client==0.21.1
prompt_toolkit==3.0.50
psutil==7.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
pycparser==2.22
pydantic==2.10.6
pydantic_core==2.27.2
Pygments==2.19.1
pygraphviz==1.14
pyparsing==3.2.1
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-json-logger==3.3.0
pytz==2025.1
PyYAML==6.0.2
pyzmq==26.2.1
referencing==0.36.2
requests==2.32.3
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.23.1
ruff==0.9.10
scikit-learn==1.6.1
scipy==1.15.2
scipy-stubs==1.15.2.0
seaborn==0.13.2
Send2Trash==1.8.3
setuptools==75.8.2
six==1.17.0
sniffio==1.3.1
soupsieve==2.6
SQLAlchemy==2.0.38
stack-data==0.6.3
tabulate==0.9.0
tenacity==9.0.0
terminado==0.18.1
threadpoolctl==3.5.0
tinycss2==1.4.0
toml==0.10.2
tornado==6.4.2
tqdm==4.67.1
traitlets==5.14.3
types-python-dateutil==2.9.0.20241206
types-pytz==2025.1.0.20250204
typing_extensions==4.12.2
tzdata==2025.1
uri-template==1.3.0
urllib3==2.3.0
virtualenv==20.29.3
wcwidth==0.2.13
webcolors==24.11.1
webencodings==0.5.1
websocket-client==1.8.0
widgetsnbextension==4.0.13
xarray==2025.1.2
zipp==3.21.0