
Product and Sum Types


Ask one modelling question up front: is the concept “all of these fields together” or “exactly one of these variants”? Once that question is explicit, much of the accidental complexity in Python models stops looking inevitable.

Start With the Invalid-State Smell

You may know the data you care about and still encode it with booleans, nullable fields, and mutable classes that allow impossible combinations.

  • If several fields must all exist together, the model wants a product type.
  • If a value must be one variant or another, the model wants a tagged sum.
  • If a model can hold contradictory information at once, the shape is still too loose.

Core question
How do you replace ad-hoc dicts, mutable classes, and fragile inheritance with pure product and tagged sum types — guaranteeing exhaustive handling, immutability, and stable serialization in every pipeline stage?

This lesson introduces ADTs as the first modelling tool to reach for in this module:

  • use product types when the value is a fixed bundle of fields
  • use tagged sums when the value is one of several distinct cases
  • make those choices explicit enough that the type shape itself removes bad states

The motivating Chunk and error examples matter because they show a common failure pattern: even after adding better error flow, the underlying data shapes are still too weak to prevent contradictory or incomplete states.

The naïve (and extremely common) solution:

class ChunkState:
    def __init__(self, success=None, embedding=None, error=None):
        self.success = success
        self.embedding = embedding
        self.error = error

# somewhere deep in the code...
if state.success:
    index(state.embedding)   # oops, someone forgot to check error is not None

This is the core modelling smell to recognize instantly: boolean blindness and null soup.
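
To make the smell concrete, here is a minimal sketch that re-creates the loose class under the hypothetical name `NaiveChunkState` (renamed only to avoid clashing with the `ChunkState` alias used later) and shows that it happily accepts states no pipeline stage should ever see. The `embedding_length` helper is also an assumption made for illustration.

```python
class NaiveChunkState:
    """Loose model: three independent optional fields."""

    def __init__(self, success=None, embedding=None, error=None):
        self.success = success
        self.embedding = embedding
        self.error = error


# Both values construct without complaint, yet neither is meaningful.
contradictory = NaiveChunkState(success=True, embedding=None, error="timeout")
empty = NaiveChunkState()  # no success flag, no embedding, no error


def embedding_length(state: NaiveChunkState) -> int:
    """Hypothetical handler that trusts the boolean and nothing else."""
    if state.success:
        return len(state.embedding)  # TypeError when embedding is None
    return 0
```

Nothing in the class stops `success=True` from travelling together with an error and a missing embedding, so every handler has to re-check every field by hand.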

The production solution: model every domain concept as either a product type (AND) or a tagged sum type (OR — exactly one variant).

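from dataclasses import dataclass
from typing import Literal

# JSON is the recursive JSON type alias defined with the public API in section 3.

@dataclass(frozen=True, slots=True, kw_only=True)
class Success:
    kind: Literal["success"] = "success"
    embedding: tuple[float, ...]
    metadata: tuple[tuple[str, JSON], ...]

@dataclass(frozen=True, slots=True, kw_only=True)
class Failure:
    kind: Literal["failure"] = "failure"
    code: str
    msg: str
    attempt: int

# Define the alias after both classes so the names exist when it is evaluated.
ChunkState = Success | Failure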

Now the model itself tells you which cases exist, and later tools like match and assert_never can enforce that understanding mechanically.

Use this when incomplete state handling, silent None paths, and mutable model shapes let bad combinations slip through review.

Outcome

  1. Every dict/class soup replaced with proper product and tagged sum types.
  2. Exhaustiveness proved via mypy + assert_never.
  3. Immutable, serialisable domain models that survive refactors without silent regressions.

Tiny Non-Domain Example – Shape ADT

import math
from dataclasses import dataclass
from typing import Literal
from typing import assert_never  # Python 3.11+; typing_extensions provides it for older versions

# kw_only=True lets the defaulted `kind` tag precede the required fields;
# without it the dataclass decorator raises TypeError at class creation.
@dataclass(frozen=True, slots=True, kw_only=True)
class Circle:
    kind: Literal["circle"] = "circle"
    radius: float

@dataclass(frozen=True, slots=True, kw_only=True)
class Rectangle:
    kind: Literal["rectangle"] = "rectangle"
    width: float
    height: float

Shape = Circle | Rectangle

def area(s: Shape) -> float:
    match s:
        case Circle(radius=r):
            return math.pi * r * r
        case Rectangle(width=w, height=h):
            return w * h
    assert_never(s)  # mypy errors if you add Triangle and forget to handle it

Adding a new variant now forces the handling sites to be updated instead of silently drifting out of sync.

Why ADTs? (Three bullets every engineer should internalise)

  • Exhaustiveness: Adding a variant makes the type checker flag every handler that ends in assert_never until you update it, so there are no silent missing cases.
  • Immutability + Value semantics: Frozen + structural eq/hash → safe in sets/dicts, pure functions, cache keys (provided nested metadata values are hashable and are not mutated after construction).
  • Stable serialization: Explicit kind tag + deterministic field order (sorted tuple for metadata) → JSON round-trip without surprises.

1. Laws & Invariants (machine-checked)

| Law | Formal Statement | Enforcement |
| --- | --- | --- |
| Exhaustiveness | Every match over a sum type must handle all variants (proved by assert_never) | mypy --strict + tests |
| Immutability | Dataclass fields cannot be reassigned. Nested containers in metadata are shared references and remain mutable (do not mutate them after chunk creation to preserve value semantics) | test_adt_immutability (top-level) |
| Structural Equality | x == y iff all fields equal (stable under dict key order via sorted tuples) | test_chunk_metadata_order_independent |
| JSON Round-Trip | from_dict(to_dict(x)) == x for all instances | test_chunk_roundtrip, test_chunk_state_roundtrip |

2. Decision Table – Which ADT Construction Do You Actually Use?

| Data Shape | Has Payload? | Needs Tags? | Recommended Construction |
| --- | --- | --- | --- |
| Simple record (AND of fields) | Yes | No | @dataclass(frozen=True, slots=True, kw_only=True) |
| Simple enumeration (no data) | No | Yes | class Status(Enum): PENDING = "pending" ... |
| Tagged variants with data (OR + payload) | Yes | Yes | Union of tagged dataclasses with kind: Literal["..."] |
| Deeply nested tree | Yes | Yes | Recursive Union of tagged dataclasses |

Never use mutable classes or bare dicts for domain data.

3. Public API (fp/core.py – mypy --strict clean)

from __future__ import annotations

from dataclasses import dataclass
from typing import Iterable, Literal, Mapping, Sequence, Tuple, TypeAlias

JSONPrimitive: TypeAlias = str | int | float | bool | None
JSON: TypeAlias = JSONPrimitive | Mapping[str, "JSON"] | Sequence["JSON"]
Path = Tuple[int, ...]

def _freeze_metadata(m: Mapping[str, JSON]) -> Tuple[Tuple[str, JSON], ...]:
    # Sort by key only: values may be heterogeneous JSON and must never be compared.
    return tuple(sorted(m.items(), key=lambda kv: kv[0]))

@dataclass(frozen=True, slots=True, kw_only=True)
class Chunk:
    text: str
    path: Path
    metadata: Tuple[Tuple[str, JSON], ...]   # top-level frozen, order-independent
    version: Literal[1] = 1

def make_chunk(
    *,
    text: str,
    path: Path,
    metadata: Mapping[str, JSON],
) -> Chunk:
    return Chunk(text=text, path=path, metadata=_freeze_metadata(metadata))

def chunk_to_dict(c: Chunk) -> dict[str, JSON]:
    return {
        "version": c.version,
        "text": c.text,
        "path": list(c.path),
        "metadata": dict(c.metadata),
    }

def chunk_from_dict(d: Mapping[str, JSON]) -> Chunk:
    if d.get("version") != 1:
        raise ValueError("unsupported version")
    return make_chunk(
        text=str(d["text"]),
        path=tuple(int(i) for i in d["path"]),
        metadata=dict(d["metadata"]),
    )

# Success / Failure sum type for embedding outcomes
@dataclass(frozen=True, slots=True, kw_only=True)
class Success:
    kind: Literal["success"] = "success"
    embedding: Tuple[float, ...]
    metadata: Tuple[Tuple[str, JSON], ...]

@dataclass(frozen=True, slots=True, kw_only=True)
class Failure:
    kind: Literal["failure"] = "failure"
    code: str
    msg: str
    attempt: int

ChunkState = Success | Failure

def success(
    *,
    embedding: Iterable[float],
    metadata: Mapping[str, JSON],
) -> Success:
    return Success(
        embedding=tuple(float(x) for x in embedding),
        metadata=_freeze_metadata(metadata),
    )

def failure(*, code: str, msg: str, attempt: int) -> Failure:
    return Failure(code=code, msg=msg, attempt=attempt)

4. Reference Implementations (continued)

4.1 Tree Sum Type (recursive tagged union)

from typing import assert_never

@dataclass(frozen=True, slots=True, kw_only=True)
class TextNode:
    kind: Literal["text"] = "text"
    content: str

@dataclass(frozen=True, slots=True, kw_only=True)
class SectionNode:
    kind: Literal["section"] = "section"
    title: str
    children: Tuple["Node", ...]

@dataclass(frozen=True, slots=True, kw_only=True)
class ListNode:
    kind: Literal["list"] = "list"
    items: Tuple["Node", ...]

Node = TextNode | SectionNode | ListNode

4.2 Exhaustive Pattern Matching

def node_depth(n: Node) -> int:
    match n:
        case TextNode():
            return 0
        case SectionNode(children=children):
            return 1 + max((node_depth(c) for c in children), default=0)
        case ListNode(items=items):
            return 1 + max((node_depth(i) for i in items), default=0)
    assert_never(n)  # mypy errors if you add a new variant

4.3 JSON Round-Trip for Tagged Sum

def chunk_state_to_dict(state: ChunkState) -> dict[str, JSON]:
    base = {"kind": state.kind, "version": 1}
    if isinstance(state, Success):
        return base | {
            "embedding": list(state.embedding),
            "metadata": dict(state.metadata),
        }
    else:  # Failure
        return base | {
            "code": state.code,
            "msg": state.msg,
            "attempt": state.attempt,
        }

def chunk_state_from_dict(d: Mapping[str, JSON]) -> ChunkState:
    if d.get("version") != 1:
        raise ValueError("unsupported version")
    kind = d["kind"]
    if kind == "success":
        return success(
            embedding=d["embedding"],      # type: ignore[arg-type]
            metadata=dict(d["metadata"]),
        )
    if kind == "failure":
        return failure(
            code=d["code"],                # type: ignore[arg-type]
            msg=d["msg"],
            attempt=d["attempt"],          # type: ignore[arg-type]
        )
    raise ValueError(f"unknown kind {kind}")

4.4 Big-O & Allocation Guarantees

| Construction | Time | Heap | Notes |
| --- | --- | --- | --- |
| dataclass creation | O(1) | O(#fields) | slots=True → no per-instance __dict__ |
| Tagged union match | O(1) | O(1) | Exhaustive via assert_never |
| JSON round-trip | O(N) | O(N) | Stable via sorted tuples |

4.5 Anti-Patterns & Immediate Fixes

| Anti-Pattern | Symptom | Fix |
| --- | --- | --- |
| Mutable domain classes | Accidental mutation | frozen=True, slots=True, kw_only=True |
| Untagged Union or dict variants | Silent missing cases | kind: Literal + assert_never |
| Dict for metadata | Unstable equality/serialization | Sorted tuple of tuples |
| Inheritance for variants | Fragile, hard to exhaust | Tagged union of dataclasses |

5. Property-Based Proofs (capstone/tests/test_module_05_c01.py)

import dataclasses
from typing import assert_never

import pytest
from hypothesis import given
import hypothesis.strategies as st

# ... imports of your ADT types ...

@given(text=st.text(), path=st.lists(st.integers(), max_size=10).map(tuple),
       meta=st.dictionaries(st.text(), st.integers() | st.lists(st.integers())))
def test_chunk_immutability(text, path, meta):
    chunk = make_chunk(text=text, path=path, metadata=meta)
    with pytest.raises(dataclasses.FrozenInstanceError):
        chunk.text = "mutated"

@given(meta=st.dictionaries(st.text(), st.integers()))
def test_chunk_metadata_order_independent(meta):
    c1 = make_chunk(text="t", path=(), metadata=meta)
    c2 = make_chunk(text="t", path=(), metadata=dict(reversed(list(meta.items()))))
    assert c1 == c2
    assert hash(c1) == hash(c2)

@given(chunk=st.builds(make_chunk,
                      text=st.text(),
                      path=st.lists(st.integers(), max_size=10).map(tuple),
                      metadata=st.dictionaries(st.text(), st.integers() | st.none())))
def test_chunk_roundtrip(chunk):
    j = chunk_to_dict(chunk)
    reloaded = chunk_from_dict(j)
    assert chunk == reloaded

@given(succ=st.builds(success,
                     embedding=st.lists(st.floats(allow_nan=False, allow_infinity=False), max_size=10),
                     metadata=st.dictionaries(st.text(), st.integers())),
       fail=st.builds(failure, code=st.text(min_size=1), msg=st.text(), attempt=st.integers(min_value=1)))
def test_chunk_state_roundtrip(succ, fail):
    for state in (succ, fail):
        j = chunk_state_to_dict(state)
        reloaded = chunk_state_from_dict(j)
        assert state == reloaded

@given(node=st.recursive(
    st.builds(TextNode, content=st.text()),
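    lambda children: st.one_of(
        # children/items are tuples of nodes, so wrap the node strategy in a list and freeze it
        st.builds(SectionNode, title=st.text(), children=st.lists(children, max_size=3).map(tuple)),
        st.builds(ListNode, items=st.lists(children, max_size=3).map(tuple)),
    ),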
    max_leaves=20,
))
def test_node_exhaustive_match(node):
    def dummy(n: Node) -> int:
        match n:
            case TextNode():   return 0
            case SectionNode(): return 1
            case ListNode():    return 2
        assert_never(n)
    dummy(node)

6. Pre-Core Quiz

  1. Product type → AND of fields (dataclass)
  2. Tagged sum type → OR of variants with payloads
  3. frozen=True, slots=True → Immutability + efficiency
  4. assert_never → Exhaustiveness proof
  5. Why sorted tuple for metadata? → Stable equality & hash (independent of dict key order)

7. Post-Core Exercise

  1. Model your current chunk embedding state as a tagged sum type → add assert_never to every handler.
  2. Refactor one dict-based structure to frozen dataclass → test JSON round-trip + order-independent equality.
  3. Add a new variant to an existing sum type → verify mypy errors in all match sites.
  4. Replace a mutable class in your codebase with a frozen dataclass → measure memory improvement (slots!).

Continue with: Domain State ADTs

You now model domain data as immutable values with exhaustively handled variants, which rules out contradictory and half-populated states by construction. The rest of Module 5 composes these ADTs into larger abstractions (functors, applicatives, monoids).