Belief polarization

world building

What we learn from new information is mediated by the mental models we use to interpret the data.

In *Priors and explanation*, we saw how prior beliefs can lead to different explanations of the same observation — two people with different priors might explain the same event in radically different ways.

In that model, the two observers updated their beliefs according to Bayes’ rule. Because the observers had different prior beliefs, they inferred different explanations for the same data. However, although their explanations differed, both observers updated their beliefs in the same direction, just to different degrees. As more relevant information accumulates, the influence of priors should diminish. This idea underpins the optimism of rational discourse: shared evidence should lead to shared understanding.

But this doesn’t always seem to happen. Sometimes people will have more polarized beliefs after observing the same information. How is this possible? Is there a “rational” explanation?

This polarization can arise because we do not merely learn patterns of data; we learn how to interpret data. Each new observation can reshape the underlying causal models through which we interpret future information. Humans excel at integrating prior knowledge with observations, but “rational” does not necessitate “correct”.

import jax
import jax.numpy as jnp
from memo import memo
from memo import domain as product
from enum import IntEnum
from matplotlib import pyplot as plt
from jax.scipy.stats.norm import pdf as normpdf

normpdfjit = jax.jit(normpdf)

PolicePerformance = jnp.linspace(0, 1, 10+1, endpoint=True)

Causal = jnp.linspace(-1, 1, 40+1, endpoint=True)

Arrests = jnp.linspace(0, 1, 10+1, endpoint=True)

@jax.jit
def arrests_pdf(arrests, performance, causal_link, sigma=0.1):
    """Return the Gaussian density of an arrest rate given performance and causal link.

    The mean arrest rate is linear in police performance and centred at 0.5:
    ``causal_link * (performance - 0.5) + 0.5``. A positive ``causal_link``
    makes better performance produce more arrests, a negative one fewer, and
    ``causal_link == 0`` makes the mean independent of performance.

    Args:
        arrests: arrest rate(s) at which to evaluate the density.
        performance: police performance value(s).
        causal_link: slope linking performance to the mean arrest rate.
        sigma: standard deviation of the arrest noise (default 0.1, the
            previously hard-coded value).

    Returns:
        The normal pdf value(s), broadcast over the inputs.
    """
    mean = causal_link * (performance - 0.5) + 0.5
    return normpdf(arrests, mean, sigma)

@jax.jit
def reported_cl_pdf(reported, real, bias=0.0, sigma=2.0):
    """Return the density of a news outlet reporting ``reported`` for a true causal link ``real``.

    Reports are Gaussian around ``real * (1 + bias)`` (computed as
    ``real + real * bias``), so ``bias`` multiplicatively exaggerates
    (``bias > 0``) or dampens (``-1 < bias < 0``) the true link.

    Args:
        reported: reported causal-link value(s).
        real: true causal-link value(s).
        bias: multiplicative reporting bias (default 0.0, i.e. unbiased).
        sigma: standard deviation of the reporting noise (default 2.0, the
            previously hard-coded value).

    Returns:
        The normal pdf value(s), broadcast over the inputs.
    """
    mean = real + real * bias
    return normpdf(reported, mean, sigma)

# Recursive Bayesian viewer model.
#
# Returns the viewer's posterior over the police "causal link" (the slope that
# ties police performance to the mean arrest rate in arrests_pdf) after
# watching nobs + 1 newscasts, each perceived as reporting
# `reported_cl_observed`.
#
# Query axis:
#   _prior_expectation_cl -- a candidate causal-link value from `Causal`. The
#       model returns Pr[police.causal_link == _prior_expectation_cl], so the
#       values over this axis form the posterior distribution over `Causal`.
# Args:
#   reported_cl_observed -- the causal link the viewer perceives each
#       newscast to report.
#   nobs -- number of *earlier* newscasts. nobs = 0 starts from a uniform
#       prior and conditions on one report; each level of recursion adds one.
# NOTE(review): kept to `#` comments rather than a docstring, presumably
# because memo parses this function body itself — confirm before adding one.
@memo
def viewerModel[
    _prior_expectation_cl: Causal,
](
    reported_cl_observed, 
    nobs,
):
    viewer: knows(
        _prior_expectation_cl,
    )
    viewer: thinks[
        # Prior over the causal link: this same model's posterior after the
        # previous newscasts (recursing with nobs - 1), or uniform (wpp=1)
        # for the very first newscast.
        police: given(causal_link in Causal, wpp=(
            viewerModel[causal_link](reported_cl_observed, nobs - 1) 
            if nobs > 0 else 1)),
        # Performance is uniform over its grid; arrests depend on performance
        # through the causal link. Neither is observed in this model.
        police: chooses(performance in PolicePerformance, wpp=1),
        police: chooses(arrests in Arrests, wpp=arrests_pdf(arrests, performance, causal_link)),
        # The news outlet knows the true link and reports a noisy version of it.
        news: knows(police.causal_link),
        news: chooses(reported_cl in Causal, wpp=(
            reported_cl_pdf(reported_cl, police.causal_link)
        )),
    ]
    # Condition on the viewer perceiving the report as `reported_cl_observed`
    # (Gaussian perception noise with scale 0.2).
    viewer: observes_event(wpp=normpdfjit(news.reported_cl, reported_cl_observed, 0.2))

    return viewer[ 
        Pr[
            police.causal_link == _prior_expectation_cl
        ] 
    ]
# Posterior mean of the causal link for a viewer repeatedly shown a report of
# Causal[4] (= -0.8); iteration nobs_ conditions on nobs_ + 1 newscasts.
reported = Causal[4]
for nobs_ in range(10):
    posterior = viewerModel(
        reported, nobs_, print_table=False, return_aux=True, return_xarray=True
    ).aux.xarray
    mean_cl = jnp.dot(posterior['_prior_expectation_cl'].values, posterior.values)
    print(f"nobs: {nobs_+1} E = {mean_cl}")
nobs: 1 E = -0.06499627977609634
nobs: 2 E = -0.12871907651424408
nobs: 3 E = -0.1903264820575714
nobs: 4 E = -0.2491239309310913
nobs: 5 E = -0.3045755922794342
nobs: 6 E = -0.3563281297683716
nobs: 7 E = -0.4041990041732788
nobs: 8 E = -0.4481450915336609
nobs: 9 E = -0.48824718594551086
nobs: 10 E = -0.5246806740760803
# Posterior mean of the causal link for a viewer repeatedly shown a report of
# Causal[-4] (= +0.8); iteration nobs_ conditions on nobs_ + 1 newscasts.
reported = Causal[-4]
for nobs_ in range(10):
    posterior = viewerModel(
        reported, nobs_, print_table=False, return_aux=True, return_xarray=True
    ).aux.xarray
    mean_cl = jnp.dot(posterior['_prior_expectation_cl'].values, posterior.values)
    print(f"nobs: {nobs_+1} E = {mean_cl}")
nobs: 1 E = 0.06774599850177765
nobs: 2 E = 0.1340755820274353
nobs: 3 E = 0.19804811477661133
nobs: 4 E = 0.25888949632644653
nobs: 5 E = 0.3160253167152405
nobs: 6 E = 0.36908823251724243
nobs: 7 E = 0.41790270805358887
nobs: 8 E = 0.4624665677547455
nobs: 9 E = 0.5029036998748779
nobs: 10 E = 0.5394370555877686
@memo
def performanceInferred[
    _arrests_observed: Arrests, 
](reported_cl_observed, nobs):
    viewer: knows(_arrests_observed)
    viewer: thinks[
        police: given(causal_link in Causal, wpp=(
            viewerModel[causal_link](reported_cl_observed, nobs) 
            if nobs > 0 else 1
        )),
        police: chooses(performance in PolicePerformance, wpp=1),
        police: chooses(arrests in Arrests, wpp=arrests_pdf(arrests, performance, causal_link)),
        news: knows(police.causal_link),
        news: chooses(reported_cl in Causal, wpp=(
            reported_cl_pdf(reported_cl, police.causal_link)
        )),
    ]
    viewer: observes_event(wpp=normpdfjit(news.reported_cl, reported_cl_observed, 0.2))
    viewer: observes [police.arrests] is _arrests_observed

    return viewer[ E[police.performance] ]

# Tabulate inferred police performance against every possible arrest rate for
# a viewer shown the Causal[4] (= -0.8) report, with nobs_ = 4.
reported, nobs_ = Causal[4], 4
_ = performanceInferred(reported, nobs_, print_table=True)
+----------------------------+----------------------+
| _arrests_observed: Arrests | performanceInferred  |
+----------------------------+----------------------+
| 0.0                        | 0.8483712673187256   |
| 0.10000000149011612        | 0.7999858856201172   |
| 0.20000000298023224        | 0.7301038503646851   |
| 0.30000001192092896        | 0.645521342754364    |
| 0.4000000059604645         | 0.5646161437034607   |
| 0.5                        | 0.4999997317790985   |
| 0.6000000238418579         | 0.43538177013397217  |
| 0.699999988079071          | 0.35447949171066284  |
| 0.800000011920929          | 0.2698955237865448   |
| 0.9000000357627869         | 0.2000139355659485   |
| 1.0                        | 0.15162856876850128  |
+----------------------------+----------------------+
# Fox vs. MSNBC viewers: inferred police performance for a high observed
# arrest rate (Arrests[-2]), as a function of how many newscasts each watched.
fox_reported = Causal[4].item()
msnbc_reported = Causal[-4].item()
arrests_observed = Arrests[-2].item()

nobs_list = list(range(10))
# nobs_ = 0 already conditions on one newscast (the model observes the report
# once even with a uniform prior), so shift by one so the x-axis counts the
# newscasts actually viewed.
newscasts_viewed = [nobs_ + 1 for nobs_ in nobs_list]

foxviewer_policeperformance__arrests = list()
msnbcviewer_policeperformance__arrests = list()
for nobs_ in nobs_list:
    res_fox = performanceInferred(fox_reported, nobs_, print_table=False, return_aux=True, return_xarray=True)
    xardata_fox = res_fox.aux.xarray
    # Pick E[performance] at the single arrest rate of interest.
    foxviewer_policeperformance__arrests.append(xardata_fox.loc[arrests_observed].item())

    res_msnbc_viewer = performanceInferred(msnbc_reported, nobs_, print_table=False, return_aux=True, return_xarray=True)
    xardata_msnbc = res_msnbc_viewer.aux.xarray
    msnbcviewer_policeperformance__arrests.append(xardata_msnbc.loc[arrests_observed].item())

fig, ax = plt.subplots()
ax.plot(newscasts_viewed, foxviewer_policeperformance__arrests, label='Fox')
ax.plot(newscasts_viewed, msnbcviewer_policeperformance__arrests, label='MSNBC')
_ = ax.set_xlabel("number of newscasts viewed")
_ = ax.set_ylabel("inferred police performance")
_ = ax.set_title(f"arrests: {arrests_observed:0.3f}")
_ = ax.legend()

# Fox vs. MSNBC viewers: inferred police performance for a low observed
# arrest rate (Arrests[2]), as a function of how many newscasts each watched.
fox_reported = Causal[4].item()
msnbc_reported = Causal[-4].item()
arrests_observed = Arrests[2].item()

nobs_list = list(range(10))
# nobs_ = 0 already conditions on one newscast (the model observes the report
# once even with a uniform prior), so shift by one so the x-axis counts the
# newscasts actually viewed.
newscasts_viewed = [nobs_ + 1 for nobs_ in nobs_list]

foxviewer_policeperformance__arrests = list()
msnbcviewer_policeperformance__arrests = list()
for nobs_ in nobs_list:
    res_fox = performanceInferred(fox_reported, nobs_, print_table=False, return_aux=True, return_xarray=True)
    xardata_fox = res_fox.aux.xarray
    # Pick E[performance] at the single arrest rate of interest.
    foxviewer_policeperformance__arrests.append(xardata_fox.loc[arrests_observed].item())

    res_msnbc_viewer = performanceInferred(msnbc_reported, nobs_, print_table=False, return_aux=True, return_xarray=True)
    xardata_msnbc = res_msnbc_viewer.aux.xarray
    msnbcviewer_policeperformance__arrests.append(xardata_msnbc.loc[arrests_observed].item())

fig, ax = plt.subplots()
ax.plot(newscasts_viewed, foxviewer_policeperformance__arrests, label='Fox')
ax.plot(newscasts_viewed, msnbcviewer_policeperformance__arrests, label='MSNBC')
_ = ax.set_xlabel("number of newscasts viewed")
_ = ax.set_ylabel("inferred police performance")
_ = ax.set_title(f"arrests: {arrests_observed:0.3f}")
_ = ax.legend()


# Inferred performance vs. observed arrests for the Causal[4] report, one
# curve per number of prior newscasts.
fig, ax = plt.subplots()
reported = Causal[4].item()
for nobs_ in (0, 1, 2, 4, 6, 8):
    curve = performanceInferred(
        reported, nobs_, print_table=False, return_aux=True, return_xarray=True
    ).aux.xarray
    ax.plot(
        curve["_arrests_observed"],
        curve,
        label=f"reported: {reported:0.2f}, nObs: {nobs_+1}",
    )

_ = ax.set_xlabel("Arrests Observed")
_ = ax.set_ylabel("Inferred Performance")
fig.legend()

# Inferred performance vs. observed arrests for the Causal[-4] report, one
# curve per number of prior newscasts.
fig, ax = plt.subplots()
reported = Causal[-4].item()
for nobs_ in (0, 1, 2, 4, 6, 8):
    curve = performanceInferred(
        reported, nobs_, print_table=False, return_aux=True, return_xarray=True
    ).aux.xarray
    ax.plot(
        curve["_arrests_observed"],
        curve,
        label=f"reported: {reported:0.2f}, nObs: {nobs_+1}",
    )

_ = ax.set_xlabel("Arrests Observed")
_ = ax.set_ylabel("Inferred Performance")
fig.legend()

# Inferred performance vs. observed arrests across seven reported causal links
# evenly spaced from -1 to 1, all with nobs_ = 3.
fig, ax = plt.subplots()
nobs_ = 3
for reported in jnp.linspace(-1, 1, 7, endpoint=True):
    curve = performanceInferred(
        reported, nobs_, print_table=False, return_aux=True, return_xarray=True
    ).aux.xarray
    ax.plot(
        curve["_arrests_observed"],
        curve,
        label=f"reported: {reported:0.2f}, nObs: {nobs_+1}",
    )

_ = ax.set_xlabel("Arrests Observed")
_ = ax.set_ylabel("Inferred Performance")
fig.legend()


%reset -f
import sys
import platform
import importlib.metadata

print("Python:", sys.version)
print("Platform:", platform.system(), platform.release())
print("Processor:", platform.processor())
print("Machine:", platform.machine())

print("\nPackages:")
for name, version in sorted(
    ((dist.metadata["Name"], dist.version) for dist in importlib.metadata.distributions()),
    key=lambda x: x[0].lower()  # Sort case-insensitively
):
    print(f"{name}=={version}")
Python: 3.13.2 (main, Feb  5 2025, 18:58:04) [Clang 19.1.6 ]
Platform: Darwin 23.6.0
Processor: arm
Machine: arm64

Packages:
annotated-types==0.7.0
anyio==4.8.0
appnope==0.1.4
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==3.0.0
async-lru==2.0.4
attrs==25.1.0
babel==2.17.0
beautifulsoup4==4.13.3
bleach==6.2.0
certifi==2025.1.31
cffi==1.17.1
cfgv==3.4.0
charset-normalizer==3.4.1
click==8.1.8
comm==0.2.2
contourpy==1.3.1
cycler==0.12.1
debugpy==1.8.13
decorator==5.2.1
defusedxml==0.7.1
distlib==0.3.9
executing==2.2.0
fastjsonschema==2.21.1
filelock==3.17.0
fonttools==4.56.0
fqdn==1.5.1
h11==0.14.0
httpcore==1.0.7
httpx==0.28.1
identify==2.6.8
idna==3.10
importlib_metadata==8.6.1
ipykernel==6.29.5
ipython==9.0.1
ipython_pygments_lexers==1.1.1
ipywidgets==8.1.5
isoduration==20.11.0
jax==0.5.2
jaxlib==0.5.1
jedi==0.19.2
Jinja2==3.1.6
joblib==1.4.2
json5==0.10.0
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
jupyter-cache==1.0.1
jupyter-events==0.12.0
jupyter-lsp==2.2.5
jupyter_client==8.6.3
jupyter_core==5.7.2
jupyter_server==2.15.0
jupyter_server_terminals==0.5.3
jupyterlab==4.3.5
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.3
jupyterlab_widgets==3.0.13
kiwisolver==1.4.8
MarkupSafe==3.0.2
matplotlib==3.10.1
matplotlib-inline==0.1.7
memo-lang==1.1.0
mistune==3.1.2
ml_dtypes==0.5.1
nbclient==0.10.2
nbconvert==7.16.6
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.4.2
nodeenv==1.9.1
notebook_shim==0.2.4
numpy==2.2.3
opt_einsum==3.4.0
optype==0.9.1
overrides==7.7.0
packaging==24.2
pandas==2.2.3
pandas-stubs==2.2.3.241126
pandocfilters==1.5.1
parso==0.8.4
pexpect==4.9.0
pillow==11.1.0
platformdirs==4.3.6
plotly==5.24.1
pre_commit==4.1.0
prometheus_client==0.21.1
prompt_toolkit==3.0.50
psutil==7.0.0
ptyprocess==0.7.0
pure_eval==0.2.3
pycparser==2.22
pydantic==2.10.6
pydantic_core==2.27.2
Pygments==2.19.1
pygraphviz==1.14
pyparsing==3.2.1
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-json-logger==3.3.0
pytz==2025.1
PyYAML==6.0.2
pyzmq==26.2.1
referencing==0.36.2
requests==2.32.3
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.23.1
ruff==0.9.10
scikit-learn==1.6.1
scipy==1.15.2
scipy-stubs==1.15.2.0
seaborn==0.13.2
Send2Trash==1.8.3
setuptools==75.8.2
six==1.17.0
sniffio==1.3.1
soupsieve==2.6
SQLAlchemy==2.0.38
stack-data==0.6.3
tabulate==0.9.0
tenacity==9.0.0
terminado==0.18.1
threadpoolctl==3.5.0
tinycss2==1.4.0
toml==0.10.2
tornado==6.4.2
tqdm==4.67.1
traitlets==5.14.3
types-python-dateutil==2.9.0.20241206
types-pytz==2025.1.0.20250204
typing_extensions==4.12.2
tzdata==2025.1
uri-template==1.3.0
urllib3==2.3.0
virtualenv==20.29.3
wcwidth==0.2.13
webcolors==24.11.1
webencodings==0.5.1
websocket-client==1.8.0
widgetsnbextension==4.0.13
xarray==2025.1.2
zipp==3.21.0