Source code for cedne.core.fold_policy

"""Fold policies: how to aggregate constituent attributes during a fold.

Folding (``Network.fold_network`` / ``Network.contract_neurons`` /
``Network.contract_connections``) collapses N entities into one supernode
or superedge. Each entity carries attributes (type, recording, ligands,
receptor dicts, etc.); the question is *how the supernode should
represent that attribute*. Historically this was handled ad-hoc inside
``contract_connections`` for a small set of known fields (set-union for
ligands, sum for weight, etc.) — fine for two papers' worth of fields,
brittle as the dataset menu grows.

This module formalizes the choice as a per-attribute ``FoldPolicy``:

    FoldPolicy(name="recording", kind="timeseries", aggregator="timeseries_mean")

A ``FoldPolicySet`` is a mapping of attribute-name to policy that gets
applied during a single fold and **stored on the resulting network** so
the choice is reproducible. The default aggregator for an unregistered
attribute is ``drop`` — silent semantic drift on un-thought-about fields
is worse than missing data.

Aggregator menu (see ``apply_policy`` for semantics):

    Scalar (int/float):     sum | mean | median | max | min | mode
    Time series:            timeseries_mean | timeseries_median | timeseries_max
    List:                   mode_count | keep_all | set_union
    Dict:                   dict_union | dict_intersection | dict_collated
    Categorical:            mode | keep_all | same_or_merged
    Any kind:               drop  (return None — attribute not copied)

Time-series aggregators truncate to the shortest constituent's length
(common defensive choice for recording data with motion-rejected frames
or differing trial lengths).
"""

from __future__ import annotations

import statistics
from collections import Counter
from dataclasses import dataclass, field
from typing import Any, Dict, Iterable, List, Literal, Optional

import numpy as np


__author__ = "Sahil Moza"
__date__ = "2026-05-12"
__license__ = "MIT"


# Aggregator names per kind. Keeping the kinds open as plain strings lets
# downstream code register custom kinds without editing this module.
ScalarAgg = Literal["sum", "mean", "median", "max", "min", "mode"]
TimeseriesAgg = Literal["timeseries_mean", "timeseries_median", "timeseries_max"]
ListAgg = Literal["mode_count", "keep_all", "set_union"]
DictAgg = Literal["dict_union", "dict_intersection", "dict_collated"]
CategoricalAgg = Literal["mode", "keep_all", "same_or_merged"]

# Sentinel meaning "don't copy this attribute to the supernode". Works on
# every kind; equivalent to the historical default of "not registered =
# not folded".
DROP: str = "drop"

# Canonical menu order (UI-facing): drives both the picker layout and the
# ``valid_aggregators(kind)`` public helper. The two tables below
# (_VALID_AGGREGATORS_BY_KIND and ALL_AGGREGATORS) are derived from this
# tuple-of-tuples so none of them can drift from the menus.
_AGGREGATOR_MENUS: Dict[str, tuple] = {
    "scalar": ("sum", "mean", "median", "max", "min", "mode", DROP),
    "timeseries": (
        "timeseries_mean",
        "timeseries_median",
        "timeseries_max",
        DROP,
    ),
    "list": ("mode_count", "keep_all", "set_union", DROP),
    "dict": ("dict_union", "dict_intersection", "dict_collated", DROP),
    "categorical": ("mode", "keep_all", "same_or_merged", DROP),
}

# Membership-check companion of the menus (frozenset for O(1) lookup in
# ``FoldPolicy.__post_init__``).
_VALID_AGGREGATORS_BY_KIND: Dict[str, frozenset] = {
    kind: frozenset(menu) for kind, menu in _AGGREGATOR_MENUS.items()
}

# Every aggregator name accepted by at least one kind (including DROP).
# Built from the menus rather than hand-listed: the previous hand-written
# list had fallen behind and was missing ``"same_or_merged"``.
ALL_AGGREGATORS = frozenset(
    agg for menu in _AGGREGATOR_MENUS.values() for agg in menu
)


def valid_aggregators(kind: str) -> list:
    """Return the aggregator names accepted for ``kind``.

    The names come back as a fresh list in canonical UI menu order, so
    callers such as backend endpoints or picker UIs can use this instead
    of reaching into the private per-kind tables.

    Raises:
        KeyError: if ``kind`` is not one of the known kinds.
    """
    menu = _AGGREGATOR_MENUS[kind]
    return [*menu]


def valid_kinds() -> list:
    """Return every kind the policy system knows, in declaration order."""
    return [kind for kind in _AGGREGATOR_MENUS]


# Sentinel returned by the ``same_or_merged`` categorical aggregator when
# constituents disagree on a categorical attribute. Per the original
# author's note it mirrors MERGED_TYPE in cedne/core/neuron.py: the same
# string is defined in both places to avoid a fold_policy → neuron import
# cycle, so the two values have to be kept in sync by hand.
SAME_OR_MERGED_SENTINEL: str = "merged"


@dataclass(frozen=True)
class FoldPolicy:
    """Aggregation rule for a single named attribute during a fold.

    Instances are immutable and validated on construction: building a
    policy with an unknown ``kind``, or with an ``aggregator`` that is not
    in that kind's menu, raises ``ValueError``.

    Attributes:
        name: The attribute key (e.g. ``"recording"`` for time series,
            ``"type"`` for categorical, ``"ligands"`` for list-valued
            connection metadata). Callers use it as the lookup key
            against the live attribute on each constituent.
        kind: The shape of the attribute; it selects which aggregators
            are allowed.
        aggregator: How N constituent values are combined into one.
            Defaults to ``"drop"``, i.e. the attribute is not copied to
            the supernode (safe default — no silent merge).
    """

    name: str
    kind: Literal["scalar", "timeseries", "list", "dict", "categorical"]
    aggregator: str = DROP

    def __post_init__(self):
        """Reject unknown kinds and kind/aggregator mismatches."""
        if self.kind not in _VALID_AGGREGATORS_BY_KIND:
            raise ValueError(
                f"Unknown FoldPolicy kind: {self.kind!r}. "
                f"Expected one of {sorted(_VALID_AGGREGATORS_BY_KIND)}."
            )
        allowed = _VALID_AGGREGATORS_BY_KIND[self.kind]
        if self.aggregator in allowed:
            return
        raise ValueError(
            f"Aggregator {self.aggregator!r} is not valid for kind "
            f"{self.kind!r}. Valid: {sorted(allowed)}."
        )

    def to_dict(self) -> Dict[str, str]:
        """Serialize to a plain ``{"name", "kind", "aggregator"}`` dict."""
        return dict(name=self.name, kind=self.kind, aggregator=self.aggregator)

    @classmethod
    def from_dict(cls, d: Dict[str, str]) -> "FoldPolicy":
        """Rebuild a policy from ``to_dict`` output.

        ``"name"`` and ``"kind"`` are required keys; a missing
        ``"aggregator"`` falls back to drop.
        """
        chosen = d.get("aggregator", DROP)
        return cls(name=d["name"], kind=d["kind"], aggregator=chosen)
@dataclass
class FoldPolicySet:
    """Per-attribute policies plus a fallback for unregistered attributes.

    Attached to the resulting network after a fold so the choices that
    produced it are part of the provenance record (alongside the
    constituent subgraphs already captured by ``fold_network``).
    """

    # Registered policies, keyed by attribute name.
    policies: Dict[str, FoldPolicy] = field(default_factory=dict)
    # Aggregator used for any attribute the caller didn't explicitly
    # register. ``drop`` is the safe default — see module docstring.
    default_aggregator: str = DROP

    def add(self, policy: FoldPolicy) -> "FoldPolicySet":
        """Register ``policy`` under its own name (replacing any previous
        entry) and return ``self`` so calls can be chained."""
        self.policies[policy.name] = policy
        return self

    def get(self, name: str, kind: str) -> FoldPolicy:
        """Return the registered policy for ``name``, or a fallback.

        When ``name`` is not registered, a new ``FoldPolicy`` of the given
        ``kind`` is built with ``default_aggregator``. That construction
        is validated like any other, so it raises ``ValueError`` if the
        default aggregator is not valid for ``kind``.
        """
        if name not in self.policies:
            return FoldPolicy(
                name=name, kind=kind, aggregator=self.default_aggregator
            )
        return self.policies[name]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the set (policies and default) to plain dicts."""
        serialized = {}
        for attr_name, policy in self.policies.items():
            serialized[attr_name] = policy.to_dict()
        return {
            "policies": serialized,
            "default_aggregator": self.default_aggregator,
        }

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "FoldPolicySet":
        """Rebuild a set from ``to_dict`` output.

        A missing or empty ``"policies"`` entry yields no registered
        policies; a missing ``"default_aggregator"`` falls back to drop.
        """
        raw_policies = d.get("policies") or {}
        restored = {
            attr_name: FoldPolicy.from_dict(payload)
            for attr_name, payload in raw_policies.items()
        }
        fallback = d.get("default_aggregator", DROP)
        return cls(policies=restored, default_aggregator=fallback)
# ---------------------------------------------------------------------------
# Aggregator implementations
# ---------------------------------------------------------------------------


def _coerce_numeric(values: Iterable[Any]) -> List[float]:
    """Pull the finite, float-convertible values out of a mixed iterable.

    ``None``, anything ``float()`` rejects, and NaN / ±inf are skipped;
    input order is preserved.
    """
    out: List[float] = []
    for v in values:
        if v is None:
            continue
        try:
            f = float(v)
        except (TypeError, ValueError):
            continue
        if not np.isfinite(f):
            continue
        out.append(f)
    return out


# Scalar aggregator name -> reducer over a non-empty list of floats.
# ``statistics.mode`` returns the first mode encountered on ties
# (Python >= 3.8), which keeps the result deterministic.
_SCALAR_REDUCERS: Dict[str, Any] = {
    "sum": np.sum,
    "mean": np.mean,
    "median": np.median,
    "max": np.max,
    "min": np.min,
    "mode": statistics.mode,
}


def _agg_scalar(values: Iterable[Any], op: str) -> Optional[float]:
    """Reduce constituent scalars to a single float.

    Non-numeric / non-finite values are ignored. With nothing usable
    left, ``"sum"`` returns ``0.0`` and every other aggregator returns
    ``None``.

    Raises:
        ValueError: if ``op`` is not a scalar aggregator (only checked
            when there is at least one usable value).
    """
    arr = _coerce_numeric(values)
    if not arr:
        # ``sum`` of nothing is conventionally 0; the historical
        # contract_connections behavior initialized weight=0 even for
        # empty constituent lists, so we match that here.
        return 0.0 if op == "sum" else None
    reducer = _SCALAR_REDUCERS.get(op)
    if reducer is None:
        raise ValueError(f"Unknown scalar aggregator: {op!r}")
    return float(reducer(arr))


def _agg_timeseries(values: Iterable[Any], op: str):
    """Aggregate constituent time-series, truncating to the shortest length.

    ``None`` entries, 0-d values and empty arrays are skipped. The
    remaining arrays are cut to the shortest first-axis length, stacked,
    and reduced along the constituent axis. Returns ``None`` when no
    usable series remains.

    Raises:
        ValueError: if ``op`` is not a timeseries aggregator.
    """
    arrays = []
    for v in values:
        if v is None:
            continue
        arr = np.asarray(v)
        if arr.ndim == 0 or arr.size == 0:
            continue
        arrays.append(arr)
    if not arrays:
        return None
    min_len = int(min(a.shape[0] for a in arrays))
    stacked = np.stack([a[:min_len] for a in arrays])  # (N, min_len)
    if op == "timeseries_mean":
        return stacked.mean(axis=0)
    if op == "timeseries_median":
        return np.median(stacked, axis=0)
    if op == "timeseries_max":
        return stacked.max(axis=0)
    raise ValueError(f"Unknown timeseries aggregator: {op!r}")


def _agg_categorical(values: Iterable[Any], op: str):
    """Combine categorical values (e.g. neuron type).

    ``None`` and ``""`` are treated as missing and ignored; if nothing
    is left the result is ``None``.

    mode            — most common value (first one encountered on ties).
    keep_all        — distinct values in first-occurrence order.
    same_or_merged  — the shared value if all constituents agree,
                      otherwise ``SAME_OR_MERGED_SENTINEL``.

    Raises:
        ValueError: if ``op`` is not a categorical aggregator.
    """
    arr = [v for v in values if v is not None and v != ""]
    if not arr:
        return None
    if op == "mode":
        # statistics.mode resolves ties by returning the first mode
        # encountered (Python >= 3.8), so no tie fallback is needed.
        return statistics.mode(arr)
    if op == "keep_all":
        # preserve first-occurrence order
        seen = set()
        out = []
        for v in arr:
            key = v if isinstance(v, (str, int, float, bool)) else repr(v)
            if key not in seen:
                seen.add(key)
                out.append(v)
        return out
    if op == "same_or_merged":
        # The historical neuron type-merge policy: if every constituent
        # agrees on the value, preserve it; if they disagree, return the
        # SAME_OR_MERGED_SENTINEL ("merged") so analyses that switch on
        # type can detect the merged case explicitly. Empty input
        # returns None (handled above) — the caller is expected to skip
        # setting the attribute, matching contract_neurons's prior
        # behavior of leaving the attribute unchanged when no constituent
        # had a usable value.
        distinct = set(arr)
        if len(distinct) == 1:
            return next(iter(distinct))
        return SAME_OR_MERGED_SENTINEL
    raise ValueError(f"Unknown categorical aggregator: {op!r}")


def _set_key(x: Any) -> Any:
    """Hashable key for set-union dedup.

    Matches the legacy contract: lists are coerced to tuples (so
    ``[a, b]`` and ``(a, b)`` count as the same element) and hashable
    values pass through unchanged. Anything that is still unhashable —
    including a list whose members are themselves unhashable — is
    repr-stringified so it can still be deduplicated instead of raising.
    """
    candidate = tuple(x) if isinstance(x, list) else x
    try:
        hash(candidate)
    except TypeError:
        # e.g. dicts, sets, or tuples containing lists
        return repr(candidate)
    return candidate


def _agg_list(values: Iterable[Any], op: str):
    """Combine list-valued attributes (ligands, receptors, etc.).

    Each constituent value that is a list / tuple / set is flattened one
    level; any other non-``None`` value is treated as a single item.

    keep_all    — every item, in encounter order.
    set_union   — distinct items, in first-occurrence order.
    mode_count  — the single most common item.

    With no items at all, ``mode_count`` returns ``None`` and every other
    aggregator returns ``[]``.

    Raises:
        ValueError: if ``op`` is not a list aggregator (only checked when
            there is at least one item).
    """
    flat: List[Any] = []
    for v in values:
        if v is None:
            continue
        if isinstance(v, (list, tuple, set)):
            flat.extend(v)
        else:
            flat.append(v)
    if not flat:
        return [] if op != "mode_count" else None
    if op == "keep_all":
        return list(flat)
    if op == "set_union":
        seen = set()
        out: List[Any] = []
        for x in flat:
            key = _set_key(x)
            if key not in seen:
                seen.add(key)
                out.append(x)
        return out
    if op == "mode_count":
        return Counter(_hashable_only(flat)).most_common(1)[0][0]
    raise ValueError(f"Unknown list aggregator: {op!r}")


def _hashable_only(seq: Iterable[Any]) -> List[Any]:
    """Counter requires hashable values; cast unhashable objects via repr."""
    out = []
    for x in seq:
        try:
            hash(x)
        except TypeError:
            out.append(repr(x))
        else:
            out.append(x)
    return out


def _agg_dict(values: Iterable[Any], op: str):
    """Combine dict-valued attributes.

    Non-dict constituent values are ignored; if no dict remains the
    result is ``None``.

    union — every key from every constituent. When the same key appears
        in multiple constituents, the *first observed* value wins.
        Output is ``{k: v}``, i.e. same shape as a single-constituent
        dict. This is the deterministic replacement for what
        contract_connections historically did via a manual
        ``if k not in merged_receptors`` loop.
    intersection — only keys present in *every* constituent dict. Value
        is the first observed value per key; keys come out in the first
        constituent's order. Same shape as ``union``.
    collated — every key from every constituent. Value is a list of all
        observed values (richer form for callers that want to see every
        observation, possibly to compute summary statistics off-line).
        Output is ``{k: [v0, v1, ...]}``.

    Raises:
        ValueError: if ``op`` is not a dict aggregator.
    """
    dicts = [v for v in values if isinstance(v, dict)]
    if not dicts:
        return None
    if op == "dict_union":
        out: Dict[Any, Any] = {}
        for d in dicts:
            for k, v in d.items():
                out.setdefault(k, v)
        return out
    if op == "dict_intersection":
        # A common key is by definition in the first dict, so the first
        # dict supplies both the "first observed" value and a stable key
        # order (iterating a set here would make the order hash-dependent).
        first, rest = dicts[0], dicts[1:]
        return {k: v for k, v in first.items() if all(k in d for d in rest)}
    if op == "dict_collated":
        out_col: Dict[Any, List[Any]] = {}
        for d in dicts:
            for k, v in d.items():
                out_col.setdefault(k, []).append(v)
        return out_col
    raise ValueError(f"Unknown dict aggregator: {op!r}")


# ---------------------------------------------------------------------------
# Default policy sets for known dataset shapes
# ---------------------------------------------------------------------------
#
# DEFAULT_CONNECTION_FOLD_POLICY captures the behavior that
# ``Network.contract_connections`` had before Phase 2 — weights sum,
# transmitter / ligand / receptor list-valued metadata set-union, and
# receptor dicts merge with first-observed-value semantics. Phase 2
# refactors contract_connections to drive its merge logic through this
# policy set instead of inlining the rules. Callers that pass a custom
# policy set override this; callers that pass nothing get current
# behavior.
#
# Receptors note: pre-Phase-2 contract_connections did "first key wins"
# via ``if rk not in merged_receptors``. ``dict_union`` here means the
# same thing — first observed value per key — chosen to keep parity.
# Connection defaults: weights sum, list-valued transmitter / ligand /
# receptor metadata is set-unioned, and receptor dicts merge with
# first-observed-value semantics. Unregistered attributes are dropped.
# NOTE(review): the key "putative_neurotrasmitter_receptors" is spelled
# exactly as in the original table; confirm it matches the attribute key
# used on connections before "correcting" it.
DEFAULT_CONNECTION_FOLD_POLICY = FoldPolicySet(
    policies={
        policy.name: policy
        for policy in (
            FoldPolicy(name="weight", kind="scalar", aggregator="sum"),
            FoldPolicy(name="ligands", kind="list", aggregator="set_union"),
            FoldPolicy(
                name="neurotransmitters", kind="list", aggregator="set_union"
            ),
            FoldPolicy(
                name="putative_neurotrasmitter_receptors",
                kind="list",
                aggregator="set_union",
            ),
            FoldPolicy(name="receptors", kind="dict", aggregator="dict_union"),
        )
    },
    default_aggregator=DROP,
)

# DEFAULT_NEURON_FOLD_POLICY captures the historical contract_neurons /
# fold_network behavior for the three "bounded categorical label"
# attributes that have always been merged via the all-same-or-MERGED-
# sentinel rule (the MERGE_TRACK_ATTRS list in cedne/core/neuron.py).
# Phase 2.2 refactors contract_neurons + the batch path's neuron merge
# to drive that decision through this policy set; parity tests pin the
# behavior on identical inputs.
#
# Numeric properties (degree, length, transcript counts) and time-series
# properties (recordings) are NOT registered here — the historical code
# didn't aggregate them onto the supernode at all; their values stay on
# the constituent neurons (reachable via constituent_subgraph). Callers
# that want a per-fold aggregate (e.g. timeseries_mean on recording)
# should pass a custom FoldPolicySet that extends this one.
DEFAULT_NEURON_FOLD_POLICY = FoldPolicySet(
    policies={
        attr: FoldPolicy(name=attr, kind="categorical", aggregator="same_or_merged")
        for attr in ("type", "category", "modality")
    },
    default_aggregator=DROP,
)
def apply_policy(policy: FoldPolicy, values: List[Any]) -> Any:
    """Merge a list of constituent values according to ``policy``.

    A drop policy always yields ``None``. Otherwise the call is routed to
    the aggregator family matching ``policy.kind``, which receives the
    values and the aggregator name. Empty or missing data does not raise;
    the family returns its own "nothing to merge" value, which the caller
    can take as the cue to skip copying the attribute to the supernode.

    Raises:
        ValueError: if ``policy.kind`` is not a known kind.
    """
    if policy.aggregator == DROP:
        return None
    # kind -> aggregator family; each takes (values, aggregator_name).
    families = {
        "scalar": _agg_scalar,
        "timeseries": _agg_timeseries,
        "list": _agg_list,
        "dict": _agg_dict,
        "categorical": _agg_categorical,
    }
    family = families.get(policy.kind)
    if family is None:
        raise ValueError(f"Unknown FoldPolicy kind: {policy.kind!r}")
    return family(values, policy.aggregator)