Folds and Reductions

Page Maps

graph LR
  family["Python Programming"]
  program["Python Functional Programming"]
  section["Streaming Resilience Failure Handling"]
  page["Folds and Reductions"]
  capstone["Capstone evidence"]

  family --> program --> section --> page
  page -.applies in.-> capstone

flowchart LR
  orient["Orient on the page map"] --> read["Read the main claim and examples"]
  read --> inspect["Inspect the related code, proof, or capstone surface"]
  inspect --> verify["Run or review the verification path"]
  verify --> apply["Apply the idea back to the module and capstone"]

This lesson makes folds feel practical before they feel abstract. Treat a fold as the answer to a repeated-work problem: one traversal, one accumulator story, one place to reason about what is combined and when.

Start With the Repeated Traversal Smell

Once you can traverse a tree safely, the next temptation is to walk it again and again for each statistic you want. That is the waste this lesson surfaces.

  • If count, max depth, and text length are computed in separate passes, the design is repeating the same traversal logic.
  • If the accumulator shape is unclear, the fold reads like machinery instead of a summary of what the code is collecting.
  • If a streaming scan changes order or work bounds, the fold implementation has stopped matching the traversal contract from the previous lesson.
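The first bullet can be made concrete with a minimal sketch of the smell. The Node and TreeDoc dataclasses below are hypothetical stand-ins for the course types (only node.text and children are assumed), and the visits counter exists purely as instrumentation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:  # hypothetical stand-in for the course's node type
    text: str


@dataclass(frozen=True)
class TreeDoc:  # hypothetical stand-in for the course's TreeDoc
    node: Node
    children: tuple["TreeDoc", ...] = ()


visits = 0  # instrumentation only: how many times any node is touched


def walk(tree: TreeDoc):
    """Yield (subtree, depth) in preorder using an explicit stack."""
    global visits
    stack = [(tree, 0)]
    while stack:
        current, depth = stack.pop()
        visits += 1
        yield current, depth
        # Push right to left so the leftmost child is popped first.
        for child in reversed(current.children):
            stack.append((child, depth + 1))


# Three statistics, three full traversals.
def count_nodes(tree: TreeDoc) -> int:
    return sum(1 for _ in walk(tree))


def text_length(tree: TreeDoc) -> int:
    return sum(len(t.node.text) for t, _ in walk(tree))


def max_depth(tree: TreeDoc) -> int:
    return max(d for _, d in walk(tree))


doc = TreeDoc(Node("root"), (
    TreeDoc(Node("a"), (TreeDoc(Node("a1")),)),
    TreeDoc(Node("b")),
))

stats = (count_nodes(doc), text_length(doc), max_depth(doc))
# stats == (4, 8, 2), but visits == 12: four nodes were walked three times.
```

Each helper is fine on its own; the smell is that the same walk runs once per statistic, so the cost grows with the number of summaries rather than staying at one pass.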

Core question:
How do you replace a structurally recursive aggregation with an iterative fold (a catamorphism) that is stack-safe, lazy where laziness is needed, and able to fuse any number of independent aggregations into a single O(N) traversal?

This lesson introduces folds as the disciplined aggregation layer on top of safe traversal:

  • use one traversal to compute several related summaries
  • make the accumulator explicit so you can see exactly what state is carried
  • preserve preorder and bounded-work behavior when moving from a simple fold to a streaming scan

The motivating tree statistics are useful because they are easy to explain and immediately show why fusion matters.

The naïve recursive solution is beautiful and obvious:

def recursive_stats(tree: TreeDoc) -> tuple[int, int, int]:
    count = 1
    length = len(tree.node.text)
    max_d = 0
    for child in tree.children:
        c, l, md = recursive_stats(child)
        count += c
        length += l
        max_d = max(max_d, md + 1)
    return count, length, max_d

It works perfectly as a specification of the aggregate… until the tree is 2000 levels deep and the recursion strategy itself becomes the weak point.
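The failure is easy to reproduce. The sketch below repeats recursive_stats from above against hypothetical Node and TreeDoc stand-ins (assumptions, not the course's real classes) and builds a single-branch tree slightly deeper than the interpreter's recursion limit.

```python
import sys
from dataclasses import dataclass, field


@dataclass
class Node:  # hypothetical stand-in
    text: str


@dataclass
class TreeDoc:  # hypothetical stand-in
    node: Node
    children: list["TreeDoc"] = field(default_factory=list)


def recursive_stats(tree: TreeDoc) -> tuple[int, int, int]:
    count = 1
    length = len(tree.node.text)
    max_d = 0
    for child in tree.children:
        c, l, md = recursive_stats(child)
        count += c
        length += l
        max_d = max(max_d, md + 1)
    return count, length, max_d


def make_chain(depth: int) -> TreeDoc:
    """Build a single-branch tree iteratively, leaf first."""
    tree = TreeDoc(Node("leaf"))
    for _ in range(depth):
        tree = TreeDoc(Node("x"), [tree])
    return tree


shallow = recursive_stats(make_chain(3))  # (4, 7, 3)

# A legal input that is deeper than the call stack allows.
deep = make_chain(sys.getrecursionlimit() + 100)
try:
    recursive_stats(deep)
    overflowed = False
except RecursionError:
    overflowed = True
```

Building the chain is itself iterative, so the only thing that fails here is the aggregation strategy, not the data.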

The production solution must still tell the same aggregation story, but it has to do so in a way that is iterative, fused, and optionally streaming.

That is what a fold gives us in this module: one explicit reduction over the tree where the recursion strategy is safe and the aggregation logic remains reviewable.

Use this when you routinely aggregate statistics over tree, document, or graph structures and refuse to ship code that can raise RecursionError on pathological but legal inputs.

Outcome:
1. You will replace any recursive aggregation with an iterative fold that is formally terminating and stack-safe.
2. You will fuse arbitrary numbers of independent aggregations into a single traversal using immutable tuple accumulators.
3. You will ship streaming reductions (scan_tree) that are truly lazy and short-circuitable.

This section formalises exactly what you should be able to defend: termination, stack-safety, preorder consistency, fusion of aggregates, and bounded work for scans.


Concrete Motivating Example

Same deep Markdown-derived tree from M04C01:

graph TD
  root["Root (title)<br/>50 chars, depth 0"]
  s1["Section 1<br/>30 chars, depth 1"]
  s11["Subsection 1.1<br/>20 chars, depth 2"]
  s12["Subsection 1.2<br/>25 chars, depth 2"]
  s2["Section 2<br/>35 chars, depth 1"]
  s21["Subsection 2.1<br/>40 chars, depth 2"]
  leaf["Deep leaf<br/>10 chars, depth 2002"]
  root --> s1
  root --> s2
  s1 --> s11
  s1 --> s12
  s2 --> s21
  s21 -.intermediate levels elided.-> leaf

Desired aggregates (computed in one pass):

total_nodes       = 2004
total_text_length = 85_050
max_depth         = 2002

We want all three numbers, plus optionally a running total after each node (for progress bars, early termination, etc.).


1. Laws & Invariants (machine-checked where possible)

All laws assume finite, acyclic TreeDoc inputs (always non-empty; root node exists).

| Law | Formal Statement | Enforcement |
| --- | --- | --- |
| Termination & Stack-Safety | Completes in O(N) steps with O(1) call-stack frames for any finite acyclic tree. | Formal proof via explicit stack + Hypothesis on 5000-node chains + CI recursion-limit guard. |
| Equivalence | fold_tree(t, seed, f) == recursive_fold(t, seed, f) for all t (identical result and preorder application). | Hypothesis test_fold_vs_recursive_equivalence. |
| Fusion | Fused tuple fold equals separate folds: (count, length, max_d) == (fold_count(t), fold_len(t), fold_max_d(t)). | Hypothesis test_fusion_equivalence. |
| Bounded-Work (scan_tree) | Consuming first k partial accumulators visits exactly k nodes. | Instrumented property test_scan_bounded_work. |
| Order Law | Combiner applied in strict preorder (identical sequence to flatten(t) from M04C01). | Property test test_fold_preorder_matches_flatten. |

These laws turn “fold” from a pattern into a verifiable contract.


2. Decision Table – Which Fold Do You Actually Use?

| Need | Streaming Partials? | Multiple Values? | Recommended Variant |
| --- | --- | --- | --- |
| Single aggregate (tree) | No | No | fold_tree or fold_tree_no_path |
| Multiple aggregates (tree) | No | Yes | fold_tree_buffered with tuple accumulator (fused) |
| Running totals (tree) | Yes | No/Yes | scan_tree (optionally with tuple accumulator) |
| Linear (list/iterator) | No/Yes | No | linear_reduce / linear_accumulate |

Never use recursive aggregation in library code.
Never run multiple separate folds when a single fused tuple fold gives the same result in one pass.
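For the linear row of the table, the same ideas can be sketched with the standard library alone. The signatures of linear_reduce and linear_accumulate are not shown in this lesson, so the example uses functools.reduce and itertools.accumulate directly; the lines list and the step function are illustrative assumptions.

```python
from functools import reduce
from itertools import accumulate

lines = ["alpha", "be", "gamma"]  # illustrative input


def step(acc: tuple[int, int], line: str) -> tuple[int, int]:
    """Fused step: carry (count, total_length) in one immutable tuple."""
    count, length = acc
    return (count + 1, length + len(line))


# One pass, one final answer.
total = reduce(step, lines, (0, 0))

# One pass, every intermediate answer (the seed comes out first).
running = list(accumulate(lines, step, initial=(0, 0)))

# total == (3, 12)
# running == [(0, 0), (1, 5), (2, 7), (3, 12)]
```

The tree variants in this lesson follow the same shape: a seed, a step function, and either a single result or a stream of partial results.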


3. Public API Surface (end-of-Module-04 refactor note)

Refactor note: tree folds/scans live in funcpipe_rag.tree (capstone/src/funcpipe_rag/tree/folds.py).
funcpipe_rag.api.core re-exports the same names as a stable façade for the course modules.

from funcpipe_rag.api.core import (
    fold_count_length_maxdepth,
    fold_tree,
    fold_tree_buffered,
    fold_tree_no_path,
    linear_accumulate,
    linear_reduce,
    scan_count_length_maxdepth,
    scan_tree,
)

4. Reference Implementations

4.1 Recursive Specification (Didactic only)

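from collections import deque
from collections.abc import Callable, Iterator
from typing import TypeVar

# Shared by the implementations in this section. TreeDoc and Path (a tuple of
# child indices) come from the capstone code base.
R = TypeVar("R")

def recursive_fold(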
    tree: TreeDoc,
    seed: R,
    combiner: Callable[[R, TreeDoc, int, Path], R],
    *,
    depth: int = 0,
    path: Path = (),
) -> R:
    acc = combiner(seed, tree, depth, path)
    for i, child in enumerate(tree.children):
        acc = recursive_fold(child, acc, combiner, depth=depth + 1, path=path + (i,))
    return acc

4.2 Simple Explicit-Stack Fold (Readable reference)

def fold_tree(
    tree: TreeDoc,
    seed: R,
    combiner: Callable[[R, TreeDoc, int, Path], R],
) -> R:
    acc = seed
    stack: deque[tuple[TreeDoc, int, Path, int]] = deque([(tree, 0, (), 0)])
    while stack:
        node, depth, path, child_idx = stack.pop()
        if child_idx == 0:
            acc = combiner(acc, node, depth, path)
        if child_idx < len(node.children):
            stack.append((node, depth, path, child_idx + 1))
            child = node.children[child_idx]
            stack.append((child, depth + 1, path + (child_idx,), 0))
    return acc

4.3 Production Winner – Buffered-Path Fold (no path tuples on the traversal stack)

def fold_tree_buffered(
    tree: TreeDoc,
    seed: R,
    combiner: Callable[[R, TreeDoc, int, Path], R],
) -> R:
    """
    Same semantics as fold_tree but maintains the path using a single mutable list
    (no tuples on the traversal stack). One tuple per node is still created when
    calling combiner (intrinsic to passing the path).
    """
    acc = seed
    stack: deque[tuple[TreeDoc, int, int | None]] = deque([(tree, 0, None)])
    path: list[int] = []
    last_depth = 0

    while stack:
        node, depth, sib_idx = stack.pop()

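        # Maintain the mutable path prefix (same logic as iter_flatten_buffered).
        # Moving back up the tree: drop the indices of the subtree just finished.
        if depth < last_depth:
            del path[depth:]
        if sib_idx is not None:
            if depth > len(path):
                # First visit at this depth on the current branch.
                path.append(sib_idx)
            else:
                # Sibling at an already-recorded depth: overwrite its index.
                path[depth - 1] = sib_idx
        last_depth = depth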

        acc = combiner(acc, node, depth, tuple(path[:depth]))

        # Push children right to left so the leftmost child is popped first.
        for i in range(len(node.children) - 1, -1, -1):
            stack.append((node.children[i], depth + 1, i))

    return acc
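To see how the mutable path list behaves, it helps to run the buffered fold on a tiny tree and record each path it reports. This sketch copies fold_tree_buffered from above (type annotations dropped for brevity) and uses hypothetical Node and TreeDoc stand-ins; the combiner that collects (text, path) pairs is illustrative only.

```python
from collections import deque
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:  # hypothetical stand-in
    text: str


@dataclass(frozen=True)
class TreeDoc:  # hypothetical stand-in
    node: Node
    children: tuple["TreeDoc", ...] = ()


def fold_tree_buffered(tree, seed, combiner):
    acc = seed
    stack = deque([(tree, 0, None)])
    path: list[int] = []
    last_depth = 0
    while stack:
        node, depth, sib_idx = stack.pop()
        if depth < last_depth:
            del path[depth:]  # leaving a finished subtree
        if sib_idx is not None:
            if depth > len(path):
                path.append(sib_idx)  # going one level deeper
            else:
                path[depth - 1] = sib_idx  # moving to a sibling
        last_depth = depth
        acc = combiner(acc, node, depth, tuple(path[:depth]))
        for i in range(len(node.children) - 1, -1, -1):
            stack.append((node.children[i], depth + 1, i))
    return acc


doc = TreeDoc(Node("root"), (
    TreeDoc(Node("a"), (TreeDoc(Node("a1")), TreeDoc(Node("a2")))),
    TreeDoc(Node("b")),
))

# Collect (text, path) pairs in an immutable tuple accumulator.
trace = fold_tree_buffered(
    doc, (), lambda acc, t, depth, path: acc + ((t.node.text, path),)
)
# trace == (("root", ()), ("a", (0,)), ("a1", (0, 0)),
#           ("a2", (0, 1)), ("b", (1,)))
```

The step from a2 at path (0, 1) to b at path (1,) exercises both branches of the bookkeeping: the list is truncated on the way up, then the index at depth 1 is overwritten.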

4.4 Optimised Fold Without Path (when path unused)

def fold_tree_no_path(
    tree: TreeDoc,
    seed: R,
    combiner: Callable[[R, TreeDoc, int], R],
) -> R:
    """Third parameter is depth."""
    acc = seed
    stack: deque[tuple[TreeDoc, int, int]] = deque([(tree, 0, 0)])
    while stack:
        node, depth, child_idx = stack.pop()
        if child_idx == 0:
            acc = combiner(acc, node, depth)
        if child_idx < len(node.children):
            stack.append((node, depth, child_idx + 1))
            stack.append((node.children[child_idx], depth + 1, 0))
    return acc
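The stack-safety claim can be checked by folding over a chain deeper than the interpreter's recursion limit. The sketch below copies fold_tree_no_path from above (annotations dropped) and assumes hypothetical Node and TreeDoc stand-ins plus a make_chain helper that is not part of the course API.

```python
import sys
from collections import deque
from dataclasses import dataclass, field


@dataclass
class Node:  # hypothetical stand-in
    text: str


@dataclass
class TreeDoc:  # hypothetical stand-in
    node: Node
    children: list["TreeDoc"] = field(default_factory=list)


def fold_tree_no_path(tree, seed, combiner):
    acc = seed
    stack = deque([(tree, 0, 0)])
    while stack:
        node, depth, child_idx = stack.pop()
        if child_idx == 0:
            acc = combiner(acc, node, depth)  # first visit: combine once
        if child_idx < len(node.children):
            stack.append((node, depth, child_idx + 1))  # resume parent later
            stack.append((node.children[child_idx], depth + 1, 0))
    return acc


def make_chain(depth: int) -> TreeDoc:
    """Hypothetical helper: single-branch tree built iteratively, leaf first."""
    tree = TreeDoc(Node("leaf"))
    for _ in range(depth):
        tree = TreeDoc(Node("x"), [tree])
    return tree


depth = sys.getrecursionlimit() + 500
deep = make_chain(depth)

deepest = fold_tree_no_path(deep, 0, lambda acc, node, d: max(acc, d))
count = fold_tree_no_path(deep, 0, lambda acc, node, d: acc + 1)
# deepest == depth and count == depth + 1, with no RecursionError
```

The fold never calls itself, so the depth of the tree only affects the size of the explicit stack on the heap.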

4.5 Streaming Scan (Running Totals – Truly Lazy)

def scan_tree(
    tree: TreeDoc,
    seed: R,
    combiner: Callable[[R, TreeDoc, int, Path], R],
) -> Iterator[R]:
    """Yield running accumulator after each node in preorder – O(k) work for first k yields."""
    acc = seed
    stack: deque[tuple[TreeDoc, int, Path, int]] = deque([(tree, 0, (), 0)])
    while stack:
        node, depth, path, child_idx = stack.pop()
        if child_idx == 0:
            acc = combiner(acc, node, depth, path)
            yield acc
        if child_idx < len(node.children):
            stack.append((node, depth, path, child_idx + 1))
            child = node.children[child_idx]
            stack.append((child, depth + 1, path + (child_idx,), 0))

Note: linear_accumulate (via itertools.accumulate) yields the initial seed as the first value; scan_tree yields only post-node accumulators (no initial seed yield).
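Both points, lazy consumption and the seed difference, can be observed directly. The sketch below uses a compact path-free variant of the scan (it pushes all children at once rather than tracking child_idx, which yields the same preorder) with hypothetical Node and TreeDoc stand-ins; the visited list is instrumentation only.

```python
from dataclasses import dataclass
from itertools import accumulate, islice


@dataclass(frozen=True)
class Node:  # hypothetical stand-in
    text: str


@dataclass(frozen=True)
class TreeDoc:  # hypothetical stand-in
    node: Node
    children: tuple["TreeDoc", ...] = ()


def scan_preorder(tree, seed, combiner):
    """Simplified scan: yield the accumulator after each node, in preorder."""
    acc = seed
    stack = [(tree, 0)]
    while stack:
        node, depth = stack.pop()
        acc = combiner(acc, node, depth)
        yield acc
        for child in reversed(node.children):
            stack.append((child, depth + 1))


visited: list[str] = []  # instrumentation: which nodes were combined


def count_step(acc: int, tree: TreeDoc, depth: int) -> int:
    visited.append(tree.node.text)
    return acc + 1


doc = TreeDoc(Node("root"), (
    TreeDoc(Node("a"), (TreeDoc(Node("a1")),)),
    TreeDoc(Node("b")),
))

# Ask for two partial results; only two nodes are ever combined.
first_two = list(islice(scan_preorder(doc, 0, count_step), 2))
# first_two == [1, 2] and visited == ["root", "a"]

# itertools.accumulate with initial= emits the seed before any element.
linear = list(accumulate([1, 1, 1], initial=0))
# linear == [0, 1, 2, 3]
```

A consumer that stops early, for example once a running total crosses a threshold, pays only for the nodes it actually asked for.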

4.6 Fused Multi-Value Example (Count + Length + Max Depth)

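def fold_count_length_maxdepth(tree: TreeDoc) -> tuple[int, int, int]:
    # The parameter is named node_tree so it does not shadow the outer tree.
    def step(
        acc: tuple[int, int, int], node_tree: TreeDoc, depth: int, path: Path
    ) -> tuple[int, int, int]:
        count, length, max_d = acc
        return (
            count + 1,
            length + len(node_tree.node.text),
            max(max_d, depth),
        )

    # Seed (0, 0, 0) is safe because every TreeDoc has at least a root node.
    return fold_tree_buffered(tree, (0, 0, 0), step)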

5. Property-Based Proofs (capstone/tests/test_tree_folds.py)

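from hypothesis import given

# tree_strategy() and step_count_len_maxd are defined elsewhere in the capstone
# test suite and are not shown here.

@given(tree=tree_strategy())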
def test_fold_vs_recursive_equivalence(tree):
    rec = recursive_fold(tree, (0, 0, 0), step_count_len_maxd)
    buf = fold_tree_buffered(tree, (0, 0, 0), step_count_len_maxd)
    assert rec == buf

@given(tree=tree_strategy())
def test_fusion_equivalence(tree):
    fused = fold_count_length_maxdepth(tree)
    count = fold_tree_no_path(tree, 0, lambda a, n, d: a + 1)
    length = fold_tree_no_path(tree, 0, lambda a, n, d: a + len(n.node.text))
    max_d = fold_tree_no_path(tree, 0, lambda a, n, d: max(a, d))
    assert fused == (count, length, max_d)

@given(tree=tree_strategy())
def test_fold_preorder_matches_flatten(tree):
    from funcpipe_rag.api.core import flatten
    order_via_fold: list[Path] = []
    fold_tree(tree, None, lambda _, n, d, p: order_via_fold.append(p))
    order_via_flatten = [c.metadata["path"] for c in flatten(tree)]
    assert order_via_fold == order_via_flatten

@given(tree=tree_strategy())
def test_fold_buffered_order_matches_simple(tree):
    order_simple: list[Path] = []
    order_buf: list[Path] = []
    fold_tree(tree, None, lambda _, n, d, p: order_simple.append(p))
    fold_tree_buffered(tree, None, lambda _, n, d, p: order_buf.append(p))
    assert order_simple == order_buf
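For readers without Hypothesis installed, the same equivalence idea can be approximated with a seeded random generator. This is a weaker check than the property tests above (no shrinking, a fixed sample), and everything here is a hypothetical stand-in: Node, TreeDoc, the simplified fold, and random_tree.

```python
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:  # hypothetical stand-in
    text: str


@dataclass(frozen=True)
class TreeDoc:  # hypothetical stand-in
    node: Node
    children: tuple["TreeDoc", ...] = ()


def recursive_stats(tree):
    """Recursive specification: (count, text length, max depth)."""
    count, length, max_d = 1, len(tree.node.text), 0
    for child in tree.children:
        c, l, md = recursive_stats(child)
        count, length, max_d = count + c, length + l, max(max_d, md + 1)
    return count, length, max_d


def fold(tree, seed, step):
    """Simplified explicit-stack preorder fold (no paths)."""
    acc = seed
    stack = [(tree, 0)]
    while stack:
        current, depth = stack.pop()
        acc = step(acc, current, depth)
        for child in reversed(current.children):
            stack.append((child, depth + 1))
    return acc


def step(acc, tree, depth):
    count, length, max_d = acc
    return (count + 1, length + len(tree.node.text), max(max_d, depth))


def random_tree(rng, depth=0, max_depth=4):
    """Small random tree; shallow, so building it recursively is safe."""
    n_children = 0 if depth == max_depth else rng.randint(0, 3)
    text = "x" * rng.randint(0, 10)
    children = tuple(
        random_tree(rng, depth + 1, max_depth) for _ in range(n_children)
    )
    return TreeDoc(Node(text), children)


rng = random.Random(0)  # fixed seed keeps the check reproducible
trees = [random_tree(rng) for _ in range(200)]
mismatches = sum(
    1 for t in trees if fold(t, (0, 0, 0), step) != recursive_stats(t)
)
# mismatches == 0
```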

6. Big-O & Allocation Guarantees (peak auxiliary memory)

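| Variant | Time | Call-stack | Peak auxiliary heap | Path tuples allocated |
| --- | --- | --- | --- | --- |
| fold_tree / scan_tree | O(N) node visits | O(1) | O(depth) stack frames, each holding its own path tuple | One per node, built as path + (child_idx,) |
| fold_tree_buffered | O(N) node visits | O(1) | Pending-sibling stack (up to depth × max branching entries) plus one mutable path list | One per node, built only for the combiner call |
| fold_tree_no_path | O(N) node visits | O(1) | O(depth) stack frames | None |

A path tuple for a node at depth d has d entries, so the path-carrying variants copy O(N×depth) indices in the worst case. Result metadata (paths, when used) is intrinsic; for fold_tree_buffered the auxiliary overhead is only the explicit stack plus one mutable path list.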


7. Anti-Patterns & Immediate Fixes

| Anti-Pattern | Symptom | Fix |
| --- | --- | --- |
| Recursive aggregation in library code | RecursionError on deep trees | Replace with fold_tree_buffered |
| Separate folds for related stats | Roughly k× the traversal work for k separate passes | Fuse with tuple accumulator |
| Mutable accumulator | Aliasing / nondeterminism | Use immutable tuples |
| String concatenation in combiner | Quadratic time | Count lengths, join once at end |
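The last two rows are easiest to recognise on a linear fold. The sketch below uses functools.reduce only; the inputs and helper names are illustrative assumptions, not part of the course API.

```python
from functools import reduce

# --- Mutable accumulator: the seed object is shared between folds ---
SHARED_SEED: list[str] = []


def collect_mutating(acc: list[str], item: str) -> list[str]:
    acc.append(item)  # mutates the seed in place
    return acc


first = reduce(collect_mutating, ["a", "b"], SHARED_SEED)
second = reduce(collect_mutating, ["c"], SHARED_SEED)
# second == ["a", "b", "c"], and first is the very same list object


def collect_immutable(acc: tuple[str, ...], item: str) -> tuple[str, ...]:
    return acc + (item,)  # builds a new tuple, the seed is untouched


safe_first = reduce(collect_immutable, ["a", "b"], ())
safe_second = reduce(collect_immutable, ["c"], ())
# safe_first == ("a", "b") and safe_second == ("c",)

# --- String concatenation in the combiner ---
parts = ["fold", "s ", "and ", "reductions"]

# Anti-pattern: each step may copy everything accumulated so far.
slow = reduce(lambda acc, s: acc + s, parts, "")

# Fix: fold only the cheap summary, then join once at the end.
total_len = reduce(lambda acc, s: acc + len(s), parts, 0)
fast = "".join(parts)
# slow == fast == "folds and reductions" and total_len == 20
```

Appending to a tuple copies it as well, so an accumulator that collects many items has the same quadratic shape as string concatenation; that is why the table recommends counting in the fold and assembling the result once afterwards.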

8. Pre-Core Quiz

  1. Safe recursion over trees? → Iterative fold with explicit stack
  2. Compute 5 independent stats? → One fused tuple fold
  3. Running totals over tree? → scan_tree
  4. No path tuples on the traversal stack for deep chains? → fold_tree_buffered
  5. Equivalence guarantee? → Hypothesis vs recursive spec + order test

9. Post-Core Exercise

  1. Implement recursive max-depth → replace with fold → add equivalence + order property.
  2. Fuse count + text length + max depth + set of all doc_ids in one fold.
  3. Add scan_tree progress reporting to the RAG pipeline (yield running chunk count).
  4. Find any multi-pass aggregation in your codebase → fuse into one fold → measure speedup.

Continue with: Memoization

You now own the universal pattern for safe aggregation over any algebraic data type. Everything else in Module 4 is just specialising this fold.