# Source code for cedne.core.network

"""
Graph-based nervous system representation for CeDNe.

This module defines the `NervousSystem` class, which models a complete
neural network using a subclass of `networkx.MultiDiGraph`. It serves as
the central container for neurons (`Neuron`), connections (`Connection`),
and associated metadata, and provides high-level methods for construction,
analysis, and manipulation of neural circuits.

Main components:

- `NervousSystem`: Inherits from `networkx.MultiDiGraph` and integrates neuron
  and connection management with experimental and structural logic.

Key functionality includes:
- Creating neurons and connections from raw data or other networks
- Managing and updating network state (including filters, subgraphs, folding)
- Supporting motif search, groupings, and topological export
- Generating subgraphs based on attribute filters or structural criteria
- Contracting neurons and connections to simplify network topology
- Interfacing with experimental metadata (`Worm`, `Trial`, etc.)

This module is central to most workflows in CeDNe, serving as the graph-theoretic
and biological representation of the nervous system.
"""

__author__ = "Sahil Moza"
__date__ = "2025-04-06"
__license__ = "MIT"


import copy
import pickle
import json
import numpy as np
import networkx as nx
from .connection import Connection, ChemicalSynapse, GapJunction, ConnectionGroup
from .history import record
from .neuron import Neuron, NeuronGroup, MERGED_TYPE, MERGE_TRACK_ATTRS
from .animal import Worm
from .source import Citable


class NervousSystem(nx.MultiDiGraph, Citable):
    """Graph representation of a nervous system.

    Inherits from ``networkx.MultiDiGraph`` (neurons are the nodes,
    connections are keyed parallel edges) and from ``Citable`` (citation
    bookkeeping). This is the main high-level container in CeDNe.
    """

    def __init__(self, worm: Worm = None, network: str = "Neutral", **kwargs) -> None:
        """Initialize an empty nervous system and register it on its worm.

        Args:
            worm (Worm, optional): Owning worm. A fresh ``Worm()`` is created
                only when ``None`` is passed. Defaults to None.
            network (str, optional): Name of this network (e.g. a condition or
                network type). Stored as ``self.name`` and used as the key under
                which this object is registered in ``worm.networks``.
                Defaults to "Neutral".
            **kwargs: Extra attributes set on the instance through
                :meth:`set_property`, after all default attributes exist.
        """
        nx.MultiDiGraph.__init__(self)
        Citable.__init__(self)  # provides self.citations = {}
        # ``is None`` rather than ``or``: a caller-supplied Worm must never be
        # silently replaced just because it happens to evaluate as falsy.
        self.worm = worm if worm is not None else Worm()
        self.name = network
        self.worm.networks[network] = self
        self.groups = {}
        # Name-keyed group of every neuron in the nervous system.
        self.neurons = NeuronGroup(self, group_name="all_neurons")
        # Group of every connection, keyed by (pre, post, uid).
        self.connections = ConnectionGroup(self, group_name="all_connections")
        self.visualization_metadata = {}
        self._filtered_nodes = set()
        self._filtered_edges = set()
        for key, value in kwargs.items():
            self.set_property(key, value)

    def _parent_citables(self):
        """Walk citation resolution up to the owning Animal/Worm."""
        return (self.worm,) if self.worm is not None else ()

    @property
    def num_groups(self):
        """Return the current number of neuron groups in the nervous system."""
        return len(self.groups)
[docs] def set_property(self, key, value): """ Set a property of the nervous system. Args: key (str): The name of the property. value: The value of the property. Returns: None """ setattr(self, key, value)
[docs] def build_network(self, neuron_data, adj, label): """ Make a network with the neurons. Args: neurons: The file containing neuron information adj: The adjacency matrix label: The label for the network """ with open(neuron_data, "rb") as neuron_file: node_dict = pickle.load(neuron_file) node_labels, l1_list, l2_list, l3_list = ( node_dict.iloc[:, 0].to_list(), node_dict.iloc[:, 1].to_list(), node_dict.iloc[:, 2].to_list(), node_dict.iloc[:, 3].to_list(), ) self.create_neurons( node_labels, type=l1_list, category=l2_list, modality=l3_list ) self.setup_connections(adj, label)
[docs] def create_neurons(self, labels, **kwargs): """ Creates a set of Neuron objects based on the given labels, types, categories, modalities, and positions. Args: labels (list): A list of labels for the neurons. neuron_types (list, optional): A list of types for the neurons. Defaults to None. categories (list, optional): A list of categories for the neurons. Defaults to None. modalities (list, optional): A list of modalities for the neurons. Defaults to None. positions (dict, optional): A dictionary mapping labels to positions. Defaults to None. Returns: None """ network_args = {} for key, value in kwargs.items(): if isinstance(value, dict): if not all([k in labels for k in value.keys()]): raise ValueError( f"{key}: Dictionary keys must be one of neuron labels" ) network_args[key] = { lab: value[lab] if lab in value else None for lab in labels } elif ( isinstance(value, int) or isinstance(value, str) or isinstance(value, float) or isinstance(value, bool) ): network_args[key] = {lab: value for lab in labels} elif hasattr(value, "__len__"): if not len(value) == len(labels): raise ValueError(f"{key} must be same length as neuron labels") network_args[key] = {lab: val for (lab, val) in zip(labels, value)} else: raise NotImplementedError( f"Attribute setting not implemented for datatype {type(value)}." ) for label in labels: neuron_args = {} for key, value in network_args.items(): neuron_args[key] = value[label] Neuron(label, self, **neuron_args)
[docs] def remove_neurons(self, neurons): """Remove neurons from the network.""" for neuron in neurons: if neuron not in self.neurons: raise TypeError(f" {neuron} is not a valid neuron name.") else: self.remove_node(neuron) self.update_neurons()
[docs] def create_neurons_from(self, network, data=False): """ Creates a set of Neuron objects based on the given network. Args: network (Network): A Network object. data (bool, optional): A flag indicating whether to include data in the Neuron objects. Defaults to False. """ ## Check if network object is a NervousSystem object if not isinstance(network, NervousSystem): raise TypeError("The network object must be a NervousSystem object") if not data: for node in network.nodes: Neuron(node.name, self) else: for node, data in network.nodes(data=True): Neuron(node.name, self, **data)
[docs] def create_connections(self, connection_dict): """Creates a set of connections from a dictinary of connections with pre-post pairs as keys and data as values.""" for (pre, post), data in connection_dict.items(): if pre not in self.neurons or post not in self.neurons: raise TypeError( "Input dictionary must use neuron names for connection IDs" ) n1 = self.neurons[pre] n2 = self.neurons[post] if len(data): conn = Connection(n1, n2, **data) else: conn = Connection(n1, n2) self.connections.update({(n1, n2, conn.uid): conn})
[docs] def remove_connections(self, connections): """Remove connections from the network""" for connection in connections: if isinstance(connection, Connection): self.remove_edge(connection.pre, connection.post) elif isinstance(connection, tuple): if isinstance(connection[0], Neuron) and isinstance( connection[1], Neuron ): n1 = connection[0] n2 = connection[1] self.remove_edge(n1, n2) elif isinstance(connection[0], str) and isinstance(connection[1], str): if connection[0] in self.neurons and connection[1] in self.neurons: n1 = self.neurons[connection[0]] n2 = self.neurons[connection[1]] self.remove_edge(n1, n2) else: raise NameError( f"{connection[0]} and {connection[1]} not in the network" ) else: raise TypeError( "Connections must be either a list of tuples of neurons or neuron names, or a list of Connections." ) else: raise TypeError( "Connections must be either a list of tuples of neurons or neuron names, or a list of Connections." ) self.update_connections()
[docs] def remove_all_connections(self): """Remove all connections from the network""" self.remove_connections(self.connections)
[docs] def create_connections_from(self, network, data=False): """ Creates a set of Connection objects based on the given network. Args: network (Network): A Network object. data (bool, optional): A flag indicating whether to include data in the Connection objects.\ Defaults to False. """ ## Check if network object is a NervousSystem object if not isinstance(network, NervousSystem): raise TypeError("The network object must be a NervousSystem object") for u, v, k, edge_data in network.edges(keys=True, data=True): n1 = self.neurons[u.name] n2 = self.neurons[v.name] if not data: self.connections.update({(n1, n2, k): Connection(n1, n2, k)}) else: self.connections.update( {(n1, n2, k): Connection(n1, n2, k, **edge_data)} )
[docs] def update_neurons(self): """ Synchronizes the neurons dictionary with the network's nodes. This should only be needed if the network's nodes are modified directly. """ self.neurons.clear() for node in self.nodes: if node.name not in self.neurons: self.neurons[node.name] = node
[docs] def update_connections(self): """ Update the dictionary of connections. Need more precaution here. """ # print({connection_id: self.connections[connection_id] for connection_id in self.connections}) pop_conns = [] ## Smells good ;)! for connection_id in self.connections: if connection_id not in self.edges: pop_conns.append(connection_id) for pop_conn in pop_conns: if pop_conn in self.connections: self.connections.pop(pop_conn) for n in self.neurons: self.neurons[n].update_connections()
[docs] def update_network(self): """ Update the network by setting the network attribute of all connections to self. """ for node in self.nodes: node.network = self for _, c in self.connections.items(): c.network = self
[docs] def setup_connections( self, adjacency, connection_type, input_type="adjacency", **kwargs ): """ Set up connections between neurons based on the adjacency matrix and edge type. """ if input_type == "adjacency": for source_id, neighbors in adjacency.items(): for target_id, properties in neighbors.items(): if "weight" in properties: if properties["weight"] == 0: continue else: source_neuron = self.neurons[source_id] target_neuron = self.neurons[target_id] edge_weight = properties["weight"] else: source_neuron = self.neurons[source_id] target_neuron = self.neurons[target_id] edge_weight = 1 # edge_id = self.add_edge( # source_neuron, target_neuron, # weight=edge_weight, color='k', connection_type=connection_type # ) connection = Connection( # source_neuron, target_neuron, edge_id, connection_type, weight=edge_weight source_neuron, target_neuron, connection_type=connection_type, weight=edge_weight, ) self.connections[(source_neuron, target_neuron, connection.uid)] = ( connection ) elif input_type == "edge": source_neuron = self.neurons[adjacency["pre"]] target_neuron = self.neurons[adjacency["post"]] edge_weight = adjacency["weight"] # edge_id = self.add_edge( # source_neuron, target_neuron, # weight=edge_weight, color='k', connection_type=connection_type # ) # connection = Connection(source_neuron, target_neuron, edge_id, connection_type,\ # weight=edge_weight, **kwargs) connection = Connection( source_neuron, target_neuron, connection_type=connection_type, weight=edge_weight, **kwargs, ) self.connections[(source_neuron, target_neuron, connection.uid)] = ( connection ) else: raise NotImplementedError( "Not implemented for this input type. Try 'adjacency' or 'edge'." )
[docs] def setup_chemical_connections(self, chemical_adjacency, **kwargs): """ Set up chemical connections in the network based on the given adjacency dictionary. Parameters: chemical_adjacency (dict): A dictionary representing the adjacency of chemical synapses. The keys are source neurons and the values are dictionaries where the keys are target neurons and the values are dictionaries containing the connection data. Returns: None This function iterates over the `chemical_adjacency` dictionary and adds chemical synapse edges between source neurons and target neurons if the connection weight is greater than 0. It uses the `add_edge` method to add the edge to the network and creates a `Connection` object to store the connection details. The created connection is added to the `connections` dictionary using a tuple of the source neuron, target neuron, and edge key as the key. """ connection_type = "chemical-synapse" for source_neuron, target_neurons in chemical_adjacency.items(): for target_neuron, connection_data in target_neurons.items(): if connection_data["weight"] > 0: # edge_key = self.add_edge( # self.neurons[source_neuron], # self.neurons[target_neuron], # weight=connection_data['weight'], # color='orange', # edgeType=connection_type # ) connection = ChemicalSynapse( self.neurons[source_neuron], self.neurons[target_neuron], # edge_key, connection_type=connection_type, weight=connection_data["weight"], color="orange", **kwargs, ) self.connections[ ( self.neurons[source_neuron], self.neurons[target_neuron], connection.uid, ) ] = connection
# self.add_edges_from(e) # Add edge attributes here.
[docs] def setup_gap_junctions(self, gap_junction_adjacency): """ Set up gap junctions in the network based on the given adjacency dictionary. Parameters: gap_junction_adjacency (dict): A dictionary representing the adjacency of gap junctions. The keys are source neurons and the values are dictionaries where the keys are target neurons and the values are dictionaries containing the connection data. Returns: None This function iterates over the `gap_junction_adjacency` dictionary and adds gap junction edges between source neurons and target neurons if the connection weight is greater than 0. It uses the `add_edge` method to add the edge to the network and creates a `Connection` object to store the connection details. The created connection is added to the `connections` dictionary using a tuple of the source neuron, target neuron, and edge key as the key. Note: - The `add_edge` method is assumed to be defined in the class. - The `Connection` class is assumed to be defined in the class. - The `neurons` dictionary is assumed to be defined in the class. """ connection_type = "gap-junction" for source_neuron, target_neurons in gap_junction_adjacency.items(): for target_neuron, connection_data in target_neurons.items(): if connection_data["weight"] > 0: # edge_key = self.add_edge( # self.neurons[source_neuron], # self.neurons[target_neuron], # weight=connection_data['weight'], # color='gray', # connection_type=connection_type # ) connection = GapJunction( self.neurons[source_neuron], self.neurons[target_neuron], # edge_key, connection_type=connection_type, color="gray", weight=connection_data["weight"], ) self.connections[ ( self.neurons[source_neuron], self.neurons[target_neuron], connection.uid, ) ] = connection
[docs] def load_neuron_data(self, file, file_format="summary-xlsx"): """Standard formats to load data into the network"""
# pass
[docs] def load_connection_data(self, file, file_format="summary-xlsx"): """Standard formats to load data into the network"""
# pass
[docs] @record("subnetwork") def subnetwork( self, neuron_names=None, name=None, connections=None, as_view=False, data=True ): """ Generates a subgraph of the network based on the given list of neuron names. Args: neuron_names (List[str]): List of neuron names to include in the subgraph. connections (List[tuple]): List of connections to include in the subgraph. Returns: NervousSystem: A deep copy of the subgraph generated from the neuron_names or connections. The subgraph contains a dictionary of neurons with their names as keys. """ if not as_view: if data == True: graph_copy = self.copy(copy_type="deep_with_data", name=name) else: graph_copy = self.copy(copy_type="deep_without_data", name=name) assert not ( neuron_names and connections ), "Specify either neuron_names or connections, not both." if neuron_names is not None: missing = [n for n in neuron_names if n not in graph_copy.neurons] if missing: raise KeyError(f"Neurons not found in network: {missing}") selected_names = set(neuron_names) nodes_to_remove = [ node for node in list(graph_copy.nodes) if node.name not in selected_names ] graph_copy.remove_nodes_from(nodes_to_remove) subgraph = graph_copy # subgraph_nodes = [graph_copy.neurons[name] for name in neuron_names] # subgraph = graph_copy.subgraph(subgraph_nodes) # subgraph.connections = {key: value for key, value in graph_copy.connections.items()\ # if key[0] in subgraph_nodes and key[1] in subgraph_nodes} elif connections is not None: new_connections = [ ( graph_copy.neurons[conn[0].name], graph_copy.neurons[conn[1].name], conn[2], ) for conn in connections ] new_connections = [ graph_copy.connections[key]._id for key in new_connections ] subnet = graph_copy.edge_subgraph( new_connections ) # That will put it through custom copy again. 
subgraph = NervousSystem( self.worm, network=name or self.name + "_subnetwork" ) subgraph.create_neurons_from(subnet, data=data) subgraph.create_connections_from(subnet, data=data) # subgraph.connections = {key: value for key, value in graph_copy.connections.items()\ # if key in new_connections} else: subgraph = graph_copy subgraph.update_network() subgraph.update_neurons() subgraph.update_connections() else: if neuron_names is not None: filter_neurons = [self.neurons[name] for name in neuron_names] subgraph = self.subgraph_view(filter_neurons=filter_neurons) elif connections is not None: filter_neurons = list( set([neu for c in connections for neu in [c[0], c[1]]]) ) subgraph = self.subgraph_view( filter_connections=connections, filter_neurons=filter_neurons ) else: subgraph = self return subgraph # subgraph.copy(as_view)
[docs] def join_networks(self, networks, mode="consensus"): """Goes through the list of networks and joins them to the current graph.""" assert all([isinstance(network, NervousSystem) for network in networks]) assert len(set([network.name for network in networks])) == len(networks) joined_networks = (network.name for network in networks) all_neurons = {self.name: self.neurons} all_connections = {self.name: self.connections} combined_network = NervousSystem(network=f"{'-'.join(joined_networks)}") for network in networks: all_neurons[network.name] = network.neurons all_connections[network.name] = network.connections if mode == "consensus": # All neurons that are common between networks are chosen. All edges that are common are picked. # The weight of the edge is the average of weights of all networks. neuron_set = [ set(neuron for neuron in network.neurons) for _netname, network in all_neurons.items() ] connection_set = [ set( (edge[0].name, edge[1].name, edge[2]) for edge in network.connections.keys() ) for _netname, network in all_connections.items() ] # print(set(neuron.name for neuron in network.neurons) for _netname, network in all_neurons.items() ) joined_neurs = set.intersection(*neuron_set) joined_conns = set.intersection(*connection_set) combined_network.create_neurons(joined_neurs) for edge in joined_conns: source_neuron, target_neuron, connection_type = edge weights = [ connections.connections[ ( connections.network.neurons[edge[0]], connections.network.neurons[edge[1]], edge[2], ) ].weight for _netname, connections in all_connections.items() ] # edge_weight = np.mean([all_connections[netname].connections[ (all_connections[netname].neurons[source_neuron], all_connections[netname].neurons[target_neuron], connection_type) ].weight for netname in all_connections.keys()]) edge_weight = np.mean(weights) # edge_id = combined_network.add_edge( # combined_network.neurons[source_neuron], combined_network.neurons[target_neuron], # weight=edge_weight, color='k', 
connection_type=connection_type # ) # connection = Connection( # combined_network.neurons[source_neuron], combined_network.neurons[target_neuron], edge_id, connection_type, weight=edge_weight # ) connection = Connection( combined_network.neurons[source_neuron], combined_network.neurons[target_neuron], connection_type=connection_type, weight=edge_weight, ) connection.set_property("joined_networks", joined_networks) combined_network.connections[ ( combined_network.neurons[source_neuron], combined_network.neurons[target_neuron], connection.uid, ) ] = connection return combined_network
[docs] @record("fold_network") def fold_network( self, fold_by, name=None, data="collect", exceptions=None, self_loops=True, legacy=False, fold_policy=None, ): """ Fold the network based on a partition. Args: fold_by (dict[str, list[str]]): Mapping from each merged-neuron name to the list of original neuron names that should collapse into it. Singleton lists are treated as renames. name (str, optional): Name for the resulting NervousSystem. data (str, optional): One of: - 'collect': Preserve every original edge as a separate parallel in the folded view (default). - 'clean': Sum weights of parallel edges per ``(folded_pre, folded_post, connection_type)`` and union list-valued edge metadata (ligands, neurotransmitters, putative receptors, receptor dicts). exceptions (list[str], optional): Neuron names that should NOT be folded into their class — they pass through to the folded view under their original names. self_loops (bool, optional): If False, intra-class edges (i.e. edges whose endpoints both belong to the same merged class) are dropped from the result. Defaults to True. legacy (bool, optional): Selects between two implementations that should be behaviourally equivalent. Default ``False`` uses the batch path which builds the folded view directly from the partition map in O(V + E). Set ``True`` to use the pre-batch implementation that does N-1 pair-wise ``contract_neurons`` calls per class (O(class_size × growing_supernode_degree) per class — unusable at scale). Preserved during the rollout so callers can bisect parity issues, and scheduled for removal once the fast path has soaked. fold_policy (FoldPolicySet, optional): Per-attribute aggregation policy applied when merging parallel edges in ``data='clean'`` mode. ``None`` (default) → use ``DEFAULT_CONNECTION_FOLD_POLICY``, which encodes the historical contract (weights sum, list-valued metadata set-union, receptor dicts merge with first-observed-value semantics). 
Both the batch and legacy fold paths honor the policy (Phase 2.1); strict parity tests pin behavior on identical inputs across both. The applied policy is stamped onto the result as ``folded.fold_policy`` for provenance. Returns: NervousSystem: The folded graph. Each merged neuron carries ``constituents``, ``is_merged``, ``constituent_subgraph`` and the merge-policy-resolved ``MERGE_TRACK_ATTRS``. Both code paths produce structurally equivalent NervousSystem objects; see ``_fold_network_batch`` for the precise preservation contract. """ assert isinstance(fold_by, dict), "Enter a dictionary with neuron class\ names as keys and the neurons to fold as values. If there is only one\ neuron in the list of values, the neuron will be renamed to the key." if exceptions is None: exceptions = [] # Validate every fold's destination name BEFORE touching the # graph. Renaming onto an unrelated existing neuron silently # produces duplicate names that crash subsequent copies; the # underlying ``contract_neurons`` has the same guard but # surfacing the error at the fold level gives a clearer message # and avoids partially applying a multi-fold dict. for merged_nodename, nodes_to_fold in fold_by.items(): if merged_nodename in self.neurons and merged_nodename not in nodes_to_fold: raise ValueError( f"Cannot fold into '{merged_nodename}': a neuron with " f"that name already exists in the network and is not " f"one of the neurons being folded " f"({list(nodes_to_fold)!r})." ) # The partition must be disjoint: a neuron can belong to at # most one merged class. Without this check the batch path # silently overwrites ``rename_map`` (last-write-wins) and the # legacy path also misbehaves — both produce silent data loss # / duplication rather than a clean error. Excepted members # are ignored: they pass through to the folded view regardless # of how many classes nominally list them. 
seen_assignments: dict[str, str] = {} for merged_nodename, nodes_to_fold in fold_by.items(): for n in nodes_to_fold: if n in exceptions: continue prior = seen_assignments.get(n) if prior is not None and prior != merged_nodename: raise ValueError( f"Cannot fold: neuron '{n}' is listed under " f"multiple merged classes ('{prior}' and " f"'{merged_nodename}'). Each neuron may appear " f"in at most one class." ) seen_assignments[n] = merged_nodename # Both legacy and batch paths now honor fold_policy (Phase 2 + # 2.1). The default ``None`` means use DEFAULT_CONNECTION_FOLD_POLICY, # which encodes the pre-Phase-2 contract — strict parity tests # in test_fold_policy_parity.py pin behavior on identical # inputs across both paths. if legacy: result = self._fold_network_legacy( fold_by, name, data, exceptions, self_loops, fold_policy=fold_policy, ) else: result = self._fold_network_batch( fold_by, name, data, exceptions, self_loops, fold_policy=fold_policy, ) # Always stamp the fold policy on the result so provenance is # universal — even in 'collect' mode (no merging) callers can # introspect that this network came from a fold. We import # lazily to avoid a circular import at module load. from .fold_policy import DEFAULT_CONNECTION_FOLD_POLICY, FoldPolicySet if fold_policy is not None: result.fold_policy = fold_policy elif data == "clean": result.fold_policy = DEFAULT_CONNECTION_FOLD_POLICY else: # 'collect' folds don't merge edges; record an empty # policy set so the attribute is always present and # downstream code can branch on its emptiness. result.fold_policy = FoldPolicySet() return result
def _fold_network_legacy(
    self, fold_by, name, data, exceptions, self_loops, fold_policy=None
):
    """Pre-batch fold: pair-wise ``contract_neurons`` for every class
    member after the first, then optional ``contract_connections``.

    Preserved for parity comparison. O(class_size × supernode_degree)
    per class — unusable at scale (a 43k-member class on a 17M-edge
    graph wedges for many hours). The batch implementation
    (``_fold_network_batch``) is the default.

    Participating members of a class are the names in
    ``fold_by[cname]`` that are not in ``exceptions`` and exist in this
    network — the same rule ``_fold_network_batch`` uses. Two or more
    participating members are contracted into ``cname``; exactly one is
    renamed to ``cname``; none leaves the class untouched.

    Args:
        fold_by (dict): merged-class name -> list of neuron names.
        name (str): name given to the ``deep_with_data`` working copy.
        data (str): ``'collect'`` keeps every edge; ``'clean'`` then
            collapses parallels via ``contract_connections``.
        exceptions (list): neuron names that are never folded.
        self_loops (bool): forwarded to ``contract_neurons``.
        fold_policy (FoldPolicySet, optional): forwarded to
            ``contract_neurons`` and ``contract_connections``.

    Returns:
        NervousSystem: the folded network built from a
        ``copy(copy_type="deep_with_data")`` of ``self``.

    Raises:
        ValueError: if ``data`` is neither ``'collect'`` nor ``'clean'``.
    """
    import warnings

    # Reject an unknown mode before any copying/contraction work;
    # otherwise it would fall through every branch and return None.
    if data not in ("collect", "clean"):
        raise ValueError(f"Unknown data mode for fold_network: {data!r}")

    def _participants(nodes_to_fold):
        # Members that take part in a fold: not excepted and present in
        # the pre-fold network.
        return [
            n for n in nodes_to_fold if n not in exceptions and n in self.neurons
        ]

    # Capture per-fold constituent subgraphs from the *pre-fold* graph
    # BEFORE any contraction mutates the working copy, so drill-down can
    # reconstruct the un-folded state. Capture happens at fold
    # granularity (one hierarchy level per fold, not per pair-wise
    # step). ``_silent=True`` keeps these copies out of the animal log.
    constituent_subgraphs = {}
    for merged_nodename, nodes_to_fold in fold_by.items():
        present = _participants(nodes_to_fold)
        if len(present) <= 1:
            continue
        try:
            constituent_subgraphs[merged_nodename] = self.subnetwork(
                neuron_names=present,
                name=f"Constituents of {merged_nodename}",
                _silent=True,
            )
        except Exception as err:  # best-effort: must not block the fold
            warnings.warn(
                f"fold_network: could not capture the constituent subgraph "
                f"of '{merged_nodename}' ({err!r}); the merged neuron will "
                f"have no constituent_subgraph.",
                RuntimeWarning,
                stacklevel=2,
            )

    graph_copy = self.copy(copy_type="deep_with_data", name=name)

    # Internal pair-wise contractions are an implementation detail of the
    # fold — ``_silent=True`` keeps N-1 contract_neurons events per fold
    # out of the animal log.
    for merged_nodename, nodes_to_fold in fold_by.items():
        members = _participants(nodes_to_fold)
        if len(members) > 1:
            merged_node = members[0]
            for other in members[1:]:
                graph_copy.contract_neurons(
                    (merged_node, other),
                    merged_nodename,
                    data=data,
                    self_loops=self_loops,
                    fold_policy=fold_policy,
                    _silent=True,
                )
                merged_node = merged_nodename
        elif len(members) == 1:
            # Singleton class: just rename the single member.
            graph_copy.neurons[members[0]].name = merged_nodename

    graph_copy.update_network()
    graph_copy.update_neurons()
    graph_copy.reassign_connections()

    def _attach_constituent_subgraphs(result_graph):
        # Attach each captured pre-fold subgraph to its merged neuron and
        # mirror it into the nx node dict so later copies propagate it.
        for merged_name, subg in constituent_subgraphs.items():
            if merged_name not in result_graph.neurons:
                continue
            merged_neuron = result_graph.neurons[merged_name]
            merged_neuron.constituent_subgraph = subg
            result_graph.nodes[merged_neuron]["constituent_subgraph"] = subg

    if data == "collect":
        _attach_constituent_subgraphs(graph_copy)
        return graph_copy

    # data == "clean": collapse parallels per (pre, post, connection_type).
    parsed_conns = {}
    for e, conn in graph_copy.connections.items():
        parsed_conns.setdefault((e[0], e[1], conn.connection_type), []).append(conn)
    contracted_graph = graph_copy.contract_connections(
        parsed_conns, fold_policy=fold_policy, _silent=True
    )
    _attach_constituent_subgraphs(contracted_graph)
    return contracted_graph


def _fold_network_batch(
    self, fold_by, name, data, exceptions, self_loops, fold_policy=None
):
    """Batch fold: construct the folded NervousSystem directly from the
    partition map in O(V + E + Σ class_internal_edges).

    Preservation contract:

    - The folded NervousSystem is a fresh object with fresh Neuron and
      Connection instances. Mutable attribute *values* on the new
      neurons/connections may share references with the parent's
      counterparts (same as ``copy(copy_type='deep_with_data')`` and the
      legacy fold path). Use ``copy(copy_type='deep')`` for full
      attribute isolation.
    - Merged-neuron provenance matches ``contract_neurons``: per-member
      snapshots use the same ``_snapshot`` logic; the neuron fold policy
      (``fold_policy`` or ``DEFAULT_NEURON_FOLD_POLICY``) is applied once
      across the snapshot set; constituents of transitively merged
      inputs are flattened into the new constituents dict.
    - The merged neuron's ``constituent_subgraph`` is a pre-fold
      subnetwork of the participating members plus their internal edges,
      mirrored to the nx node dict so ``deep_with_data`` copies keep it.
      If that capture fails, a ``RuntimeWarning`` is emitted and the
      class is still merged with full provenance — only
      ``constituent_subgraph`` is missing.
    - Edges:
        * ``data='collect'``: every parent edge becomes one folded edge
          with the original key and ``**edge_data`` preserved.
        * ``data='clean'``: parallels per ``(folded_pre, folded_post,
          connection_type)`` are merged through the connection fold
          policy (``fold_policy`` or ``DEFAULT_CONNECTION_FOLD_POLICY``),
          and the resulting Connection carries ``contraction_data``
          mapping each pre-fold edge's ``_id`` → its pre-fold Connection
          (only for edges registered in ``self.connections``). The
          legacy ``'clean'`` path keys ``contraction_data`` by
          post-pair-wise edge IDs instead.
    - ``self_loops=False`` drops intra-class edges from the folded view,
      matching ``nx.contracted_nodes(self_loops=False)``.
    - ``exceptions`` members pass through under their original names.

    Raises:
        ValueError: if ``data`` is neither ``'collect'`` nor ``'clean'``
            (checked before any network is constructed).
    """
    import warnings
    from collections import defaultdict

    # Validate first: constructing a NervousSystem registers it in
    # ``self.worm.networks``, so failing late would leave partial
    # networks behind.
    if data not in ("collect", "clean"):
        raise ValueError(f"Unknown data mode for fold_network: {data!r}")

    # Lazy import to avoid a circular import at module load.
    from .fold_policy import (
        DEFAULT_CONNECTION_FOLD_POLICY,
        DEFAULT_NEURON_FOLD_POLICY,
        FoldPolicy,
        apply_policy,
    )

    # ---------------------------------------------------------------
    # Step 0: participating members per class (not excepted, present),
    # computed once and reused by every later step.
    # ---------------------------------------------------------------
    present_by_class = {
        cname: [n for n in nodes if n not in exceptions and n in self.neurons]
        for cname, nodes in fold_by.items()
    }
    # Classes that actually merge (two or more participants).
    class_members = {
        cname: present
        for cname, present in present_by_class.items()
        if len(present) > 1
    }
    member_to_class = {
        m: cname for cname, present in class_members.items() for m in present
    }

    # ---------------------------------------------------------------
    # Step 1: pre-fold constituent_subgraph capture in ONE O(V + E) pass.
    # Calling ``self.subnetwork`` per class would rescan the full parent
    # edge list per class (O(classes × E)); instead bucket class-internal
    # edges in a single sweep and build each subgraph from its bucket.
    # ---------------------------------------------------------------
    by_class_edges = defaultdict(list)
    if class_members:
        for u, v, k, edge_data in self.edges(keys=True, data=True):
            cu = member_to_class.get(u.name)
            if cu is None or member_to_class.get(v.name) != cu:
                continue
            # Names + raw edge data only; fresh Connections are built
            # below, so no parent Connection objects are retained.
            by_class_edges[cu].append((u.name, v.name, k, edge_data))

    constituent_subgraphs = {}
    for cname, members in class_members.items():
        try:
            sub = NervousSystem(self.worm, network=f"Constituents of {cname}")
            # Same as ``deep_with_data``: each new Neuron receives the
            # parent's nx node attribute dict as **kwargs.
            for member_name in members:
                parent_n = self.neurons[member_name]
                Neuron(member_name, sub, **dict(self.nodes[parent_n]))
            for un, vn, k, edge_data in by_class_edges.get(cname, ()):
                src = sub.neurons[un]
                dst = sub.neurons[vn]
                # Keep the parent edge key and every attribute it carries.
                ed = dict(edge_data)
                ct = ed.pop("connection_type", "chemical-synapse")
                wt = ed.pop("weight", 1)
                new_conn = Connection(
                    src, dst, k, connection_type=ct, weight=wt, **ed
                )
                sub.connections[(src, dst, new_conn.uid)] = new_conn
            sub.visualization_metadata = copy.deepcopy(self.visualization_metadata)
            constituent_subgraphs[cname] = sub
        except Exception as err:  # best-effort: must not block the fold
            warnings.warn(
                f"fold_network: could not capture the constituent subgraph "
                f"of '{cname}' ({err!r}); the merged neuron will have no "
                f"constituent_subgraph.",
                RuntimeWarning,
                stacklevel=2,
            )

    # ---------------------------------------------------------------
    # Step 2: partition map original_name -> folded_name. Every
    # participant (including a lone singleton participant) maps to its
    # class key; excepted/absent names pass through unchanged.
    # ---------------------------------------------------------------
    rename_map = {
        m: cname for cname, present in present_by_class.items() for m in present
    }

    # ---------------------------------------------------------------
    # Step 3: snapshot pre-fold provenance per merged class, mirroring
    # ``contract_neurons``. Driven by ``class_members`` (NOT by the
    # captured subgraphs) so a failed capture never changes the merge.
    # ---------------------------------------------------------------
    def _snapshot(neuron, n_name):
        snap = {"name": n_name}
        for attr in MERGE_TRACK_ATTRS:
            snap[attr] = getattr(neuron, attr, "")
        return snap

    neuron_policy_set = (
        fold_policy if fold_policy is not None else DEFAULT_NEURON_FOLD_POLICY
    )
    merged_class_snapshots = {}  # cname -> {orig_name: snapshot}
    merged_class_resolved_attrs = {}  # cname -> {policy-resolved attrs}
    for cname, members in class_members.items():
        snapshots = {}
        for m in members:
            mn = self.neurons[m]
            # Flatten constituents from any transitively-merged source.
            if getattr(mn, "constituents", None):
                for child_name, child_meta in mn.constituents.items():
                    snapshots.setdefault(child_name, dict(child_meta))
            snapshots.setdefault(m, _snapshot(mn, m))
        merged_class_snapshots[cname] = snapshots

        resolved = {}
        for policy in neuron_policy_set.policies.values():
            values = [meta.get(policy.name) for meta in snapshots.values()]
            merged = apply_policy(policy, values)
            if merged is None:
                # No usable constituent values: leave the attribute unset.
                continue
            resolved[policy.name] = merged
        merged_class_resolved_attrs[cname] = resolved

    # ---------------------------------------------------------------
    # Step 4: construct the folded NervousSystem and its neurons.
    # ---------------------------------------------------------------
    folded = NervousSystem(self.worm, network=name or (self.name + "_folded"))
    folded.visualization_metadata = copy.deepcopy(self.visualization_metadata)
    created = set()

    # 4a. Multi-member merged neurons. ``dict(snapshots)`` gives each
    # merged neuron its own top-level constituents dict so later folds
    # cannot leak entries into it.
    for cname, snapshots in merged_class_snapshots.items():
        attrs = dict(merged_class_resolved_attrs[cname])
        attrs["constituents"] = dict(snapshots)
        Neuron(cname, folded, **attrs)
        created.add(cname)

    # 4b. Singletons and pass-through neurons, built from the parent's nx
    # node attribute dict.
    for orig_name, orig_neuron in self.neurons.items():
        target = rename_map.get(orig_name, orig_name)
        if target in created:
            continue
        Neuron(target, folded, **dict(self.nodes[orig_neuron]))
        created.add(target)

    # ---------------------------------------------------------------
    # Step 5: aggregate edges into the folded view.
    # ---------------------------------------------------------------
    if data == "collect":
        for u, v, k, edge_data in self.edges(keys=True, data=True):
            fu = rename_map.get(u.name, u.name)
            fv = rename_map.get(v.name, v.name)
            if fu == fv and not self_loops:
                continue
            src = folded.neurons[fu]
            dst = folded.neurons[fv]
            conn = Connection(src, dst, k, **edge_data)
            folded.connections[(src, dst, conn.uid)] = conn
    else:  # data == "clean"
        # Merge parallels through the fold-policy system so this path and
        # ``contract_connections`` share one source of truth. Attributes
        # are read from the nx edge_data dicts because the nx graph, not
        # ``self.connections``, is canonical for every edge.
        policy_set = (
            fold_policy if fold_policy is not None else DEFAULT_CONNECTION_FOLD_POLICY
        )
        weight_policy = policy_set.policies.get(
            "weight", FoldPolicy("weight", "scalar", "sum")
        )
        other_policies = [
            p for p in policy_set.policies.values() if p.name != "weight"
        ]

        # Bucket edge_data by FOLDED endpoint triple; separately record
        # registered Connection objects for ``contraction_data``.
        edge_buckets = defaultdict(list)
        contraction_data_per_bucket = defaultdict(dict)
        for u, v, k, edge_data in self.edges(keys=True, data=True):
            fu = rename_map.get(u.name, u.name)
            fv = rename_map.get(v.name, v.name)
            if fu == fv and not self_loops:
                continue
            ct = edge_data.get("connection_type", "chemical-synapse")
            bucket_key = (fu, fv, ct)
            edge_buckets[bucket_key].append(edge_data)
            if (u, v, k) in self.connections:
                orig_conn = self.connections[(u, v, k)]
                contraction_data_per_bucket[bucket_key][orig_conn._id] = orig_conn

        for (fu, fv, ct), edge_data_list in edge_buckets.items():
            weights = [d.get("weight", 0) or 0 for d in edge_data_list]
            weight = apply_policy(weight_policy, weights)
            if weight is None:
                weight = 0  # Connection() requires a numeric weight
            src = folded.neurons[fu]
            dst = folded.neurons[fv]
            new_conn = Connection(src, dst, connection_type=ct, weight=weight)
            new_conn.set_property(
                "contraction_data",
                dict(contraction_data_per_bucket[(fu, fv, ct)]),
            )
            for policy in other_policies:
                values = [d.get(policy.name) for d in edge_data_list]
                merged = apply_policy(policy, values)
                # Only record attributes that actually have content.
                if merged is None or merged == [] or merged == {}:
                    continue
                new_conn.set_property(policy.name, merged)
            folded.connections[(src, dst, new_conn.uid)] = new_conn

    # ---------------------------------------------------------------
    # Step 6: for every merged class, attach its constituent subgraph
    # (when captured) and mirror constituents + MERGE_TRACK_ATTRS to the
    # nx node dict so ``deep_with_data`` copies propagate them.
    # ---------------------------------------------------------------
    for cname in merged_class_snapshots:
        if cname not in folded.neurons:
            continue
        merged_neuron = folded.neurons[cname]
        subg = constituent_subgraphs.get(cname)
        if subg is not None:
            merged_neuron.constituent_subgraph = subg
            folded.nodes[merged_neuron]["constituent_subgraph"] = subg
        folded.nodes[merged_neuron]["constituents"] = merged_neuron.constituents
        for attr in MERGE_TRACK_ATTRS:
            if hasattr(merged_neuron, attr):
                folded.nodes[merged_neuron][attr] = getattr(merged_neuron, attr)

    return folded
def adjacency(self, order=None, weighted=False, connection_type=None):
    """
    Build the adjacency matrix of the network.

    Row ``i`` / column ``j`` correspond to the i-th / j-th neuron of the
    chosen ordering; entry ``[i, j]`` describes edges from neuron ``i``
    to neuron ``j``. Edges touching a neuron outside the ordering are
    ignored.

    Args:
        order: Optional sequence of neuron names and/or Neuron objects
            fixing the row/column order. Defaults to all neurons sorted
            by name.
        weighted (bool): If True, entries are the summed ``weight`` edge
            attribute (1.0 when an edge has none). Otherwise a 0/1
            integer matrix is returned.
        connection_type: Optional string or iterable of strings selecting
            which connection types count. ``"chemical"`` /
            ``"chemical_synapse"`` mean ``"chemical-synapse"``; ``"gap"`` /
            ``"gap_junction"`` mean ``"gap-junction"`` (aliases are matched
            case-insensitively after stripping whitespace). ``"bulk"``
            selects every type other than those two; any other string is
            taken as an exact connection type name.

    Returns:
        numpy.ndarray: square float matrix when ``weighted``; integer 0/1
        matrix otherwise.

    Raises:
        KeyError: a name in ``order`` is not a neuron of this network.
        TypeError: an ``order`` entry is neither a name nor a Neuron, or a
            ``connection_type`` entry is not a string.
    """

    def _as_neuron(item):
        # Accept either a Neuron object or the name of one in this network.
        if isinstance(item, Neuron):
            return item
        if isinstance(item, str):
            if item not in self.neurons:
                raise KeyError(f"Neuron '{item}' not found in network")
            return self.neurons[item]
        raise TypeError("order entries must be neuron names or Neuron objects")

    if order is not None:
        neuron_order = [_as_neuron(item) for item in order]
    else:
        neuron_order = sorted(self.neurons.values(), key=lambda n: n.name)

    # Convenience aliases -> canonical connection type names.
    aliases = {
        "chemical": "chemical-synapse",
        "chemical-synapse": "chemical-synapse",
        "chemical_synapse": "chemical-synapse",
        "gap": "gap-junction",
        "gap-junction": "gap-junction",
        "gap_junction": "gap-junction",
    }
    primary_types = {"chemical-synapse", "gap-junction"}

    wanted = None  # None means "no type filtering"
    want_bulk = False
    if connection_type is not None:
        if isinstance(connection_type, str):
            requested = [connection_type]
        else:
            requested = list(connection_type)
        wanted = set()
        for sel in requested:
            if not isinstance(sel, str):
                raise TypeError("connection_type entries must be strings")
            normalized = sel.strip().lower()
            if normalized == "bulk":
                want_bulk = True
            else:
                # Unknown selectors are kept verbatim as exact type names.
                wanted.add(aliases.get(normalized, sel))

    position = {neuron: i for i, neuron in enumerate(neuron_order)}
    size = len(neuron_order)
    matrix = np.zeros((size, size), dtype=float)

    for pre, post, attrs in self.edges(data=True):
        if pre not in position or post not in position:
            continue
        kind = attrs.get("connection_type")
        if wanted is not None:
            bulk_edge = kind not in primary_types
            if kind not in wanted and not (want_bulk and bulk_edge):
                continue
        increment = attrs.get("weight", 1.0) if weighted else 1.0
        matrix[position[pre], position[post]] += increment

    if weighted:
        return matrix
    return (matrix > 0).astype(int)
def reassign_connections(self):
    """
    Re-key the connection registry to match the current edge set.

    Used after folding. ``contract_neurons`` calls
    ``set_property("_id", ...)`` on every Connection of the merged pair
    before ``nx.contracted_nodes`` rewires the edges. For every edge
    ``(pre, post, key)``:

    * if its edge data carries an ``_id``, the Connection registered under
      that id is re-keyed to ``(pre, post, key)``. Its ``pre``, ``post`` and
      ``_id`` are updated, and ``_id`` is deleted from the edge data;
    * otherwise the Connection already registered under
      ``(pre, post, key)`` is carried over as-is.

    The rebuilt mapping is stored on ``self._connections``, also becomes
    ``self.connections``, and ``update_connections()`` is then called.

    Raises:
        KeyError: if an edge's id has no entry in the current registry.
    """
    self._connections = {}
    for pre, post, key, attrs in self.edges(data=True, keys=True):
        edge_key = (pre, post, key)
        if "_id" not in attrs:
            # Edge untouched by contraction: keep its existing Connection.
            self._connections.update({edge_key: self.connections[edge_key]})
            continue
        # Edge was rewired: move the Connection to its new endpoints.
        moved = self.connections[attrs["_id"]]
        self._connections.update({edge_key: moved})
        moved.pre = pre
        moved.post = post
        moved._id = edge_key
        del attrs["_id"]
    self.connections = self._connections
    self.update_connections()
# for e in self.in_edges(self.neurons[contracted_name], keys=True, data=True): # self.connections.update({(e[0], e[1], e[2]): self.connections[e[3]['_id']]}) # for e in self.out_edges(self.neurons[contracted_name], keys=True, data=True): # self.connections.update({(e[0], e[1], e[2]): self.connections[e[3]['_id']]})
@record("contract_neurons")
def contract_neurons(
    self,
    pair,
    contracted_name,
    data="collect",
    copy_graph=False,
    self_loops=True,
    fold_policy=None,
):
    """
    Contract ``target`` into ``source``, redirecting target's edges to
    source and renaming source to ``contracted_name``.

    Args:
        pair (tuple[str, str]): ``(source_name, target_name)``. Source
            survives (renamed to ``contracted_name``); target is removed
            and its edges are redirected onto source.
        contracted_name (str): New name for the surviving neuron.
        data (str, optional): No-op at the core layer — accepted for
            backwards compatibility with callers that use it as a hint
            for post-merge cleanup (e.g. ``data='clean'`` meaning
            ``contract_connections`` should run after a chain of
            contractions).
        copy_graph (bool, optional): If True, work on ``self.copy()`` and
            return it; the original is untouched. If False (default),
            mutate in place and return None.
        self_loops (bool, optional): Forwarded to ``nx.contracted_nodes``.
            If False, drops any self-loop produced by edges between the
            merged pair.
        fold_policy (FoldPolicySet, optional): Per-attribute merge policy
            for the surviving neuron's tracked attributes. Defaults to
            ``DEFAULT_NEURON_FOLD_POLICY``. Also honored when
            ``copy_graph=True``.

    Returns:
        NervousSystem | None: The new graph if ``copy_graph=True``;
        ``None`` otherwise (mutates in place).

    Raises:
        ValueError: if ``contracted_name`` already names a neuron that is
            not one of the pair.

    Notes:
        Merge provenance is tracked on the surviving neuron:

        * ``Neuron.constituents``: dict mapping each pre-merge name to a
          snapshot ``{'name', <each MERGE_TRACK_ATTRS attribute>}``.
          Transitive — merging A+B then merging the result with C yields
          constituents ``{A, B, C}``.
        * ``constituents`` and every present ``MERGE_TRACK_ATTRS``
          attribute are mirrored to the nx node dict.

        Attribute policy: each policy in the fold policy set is applied
        to the constituents' snapshot values. With the default policy,
        all-same values are preserved and mixed values become the
        ``MERGED_TYPE`` sentinel; an attribute with no usable values is
        left unchanged.
    """
    source_name, target_name = pair
    if copy_graph:
        new_graph = self.copy()
        # Forward every option — including fold_policy — so the copy is
        # contracted exactly like an in-place call would be.
        new_graph.contract_neurons(
            (source_name, target_name),
            contracted_name,
            data=data,
            copy_graph=False,
            self_loops=self_loops,
            fold_policy=fold_policy,
        )
        return new_graph

    src = self.neurons[source_name]
    tgt = self.neurons[target_name]

    # Refuse to rename onto an unrelated existing neuron: the rename below
    # would otherwise create two neurons with the same name, corrupting
    # later lookups and crashing ``copy()`` / ``create_neurons_from``.
    if (
        contracted_name in self.neurons
        and contracted_name != source_name
        and contracted_name != target_name
    ):
        raise ValueError(
            f"Cannot contract into '{contracted_name}': a neuron with "
            f"that name already exists in this network and isn't one of "
            f"the pair being merged ({source_name!r}, {target_name!r})."
        )

    # ----- Merge-provenance bookkeeping --------------------------------
    # Recorded BEFORE nx.contracted_nodes mutates the graph, so transitive
    # merges accumulate through ``setdefault`` and nothing depends on
    # networkx's internal ``contraction`` attribute.

    def _snapshot(neuron, name):
        # Capture every MERGE_TRACK_ATTRS value under the original name.
        snap = {"name": name}
        for attr in MERGE_TRACK_ATTRS:
            snap[attr] = getattr(neuron, attr, "")
        return snap

    if not getattr(src, "constituents", None):
        # First merge for src — src.name is still the original here.
        src.constituents = {src.name: _snapshot(src, src.name)}
    else:
        # Re-merge: take ownership of a fresh dict before mutating it.
        # The existing dict may be shared (node attributes propagate by
        # reference) with previously captured constituent subgraphs,
        # whose provenance would otherwise silently grow.
        src.constituents = dict(src.constituents)

    # Flatten the target's own prior constituents so nested merges yield
    # a flat list of original neurons.
    if getattr(tgt, "constituents", None):
        for c_name, c_meta in tgt.constituents.items():
            src.constituents.setdefault(c_name, c_meta)

    # Record the target itself under its current name.
    src.constituents.setdefault(target_name, _snapshot(tgt, target_name))

    # Apply the per-attribute fold policy across every constituent's
    # snapshot (lazy import avoids a circular import at module load).
    from .fold_policy import (
        DEFAULT_NEURON_FOLD_POLICY,
        apply_policy,
    )

    neuron_policy_set = (
        fold_policy if fold_policy is not None else DEFAULT_NEURON_FOLD_POLICY
    )
    for policy in neuron_policy_set.policies.values():
        values = [meta.get(policy.name) for meta in src.constituents.values()]
        merged = apply_policy(policy, values)
        if merged is None:
            # No usable constituent values: leave the attribute alone.
            continue
        setattr(src, policy.name, merged)

    # Mirror merge state to the nx node dict so paths that rebuild neurons
    # via ``create_neurons_from(data=True)`` keep it. MERGE_TRACK_ATTRS is
    # iterated (not the policy set) so unregistered tracked attrs are
    # mirrored too.
    self.nodes[src]["constituents"] = src.constituents
    for attr in MERGE_TRACK_ATTRS:
        if hasattr(src, attr):
            self.nodes[src][attr] = getattr(src, attr)

    # No ``constituent_subgraph`` is attached here: that record belongs to
    # the user-level fold, so hierarchy depth counts folds, not the
    # pair-wise steps implementing them.

    # ----- Edge contraction ------------------------------------------
    # Stash each connection's id on its edge so reassign_connections can
    # re-key the registry after networkx rewires the edges.
    for _cid, conn in src.get_connections().items():
        conn.set_property("_id", conn._id)
    for _cid, conn in tgt.get_connections().items():
        conn.set_property("_id", conn._id)

    nx.contracted_nodes(self, src, tgt, copy=False, self_loops=self_loops)
    src.name = contracted_name
    self.update_neurons()
@record("contract_connections")
def contract_connections(self, contraction_dict, fold_policy=None):
    """Collapse each group of parallel connections into one connection.

    A fresh NervousSystem named ``self.name + "_copy"`` is created. Its
    neurons come from ``create_neurons_from(self, data=True)``, and one
    new Connection is added per group.

    Args:
        contraction_dict: Mapping of ``(neuron_a, neuron_b, conn_type)``
            tuples (Neuron objects plus a type string) to the list of
            Connection objects to merge into one edge.
        fold_policy (FoldPolicySet, optional): Per-attribute aggregation
            policy. ``None`` uses ``DEFAULT_CONNECTION_FOLD_POLICY``. The
            ``weight`` policy (sum when unregistered) sets the new edge's
            weight; every other policy's non-empty result is set as a
            property of the same name.

    Returns:
        NervousSystem: The folded graph. Each new connection carries
        ``contraction_data`` (a copy of ``{conn._id: conn}`` over its
        group). The policy set actually used is stored on
        ``result.fold_policy``.
    """
    # Lazy import: avoid a potential import cycle through neuron /
    # connection modules.
    from .fold_policy import (
        DEFAULT_CONNECTION_FOLD_POLICY,
        FoldPolicy,
        apply_policy,
    )

    if fold_policy is None:
        policy_set = DEFAULT_CONNECTION_FOLD_POLICY
    else:
        policy_set = fold_policy

    # Weight is a Connection-constructor argument, so its policy is kept
    # apart from the ones that drive set_property.
    weight_policy = policy_set.policies.get(
        "weight", FoldPolicy("weight", "scalar", "sum")
    )
    attribute_policies = [
        policy for policy in policy_set.policies.values() if policy.name != "weight"
    ]

    result = NervousSystem(self.worm, self.name + "_copy")
    result.create_neurons_from(self, data=True)

    merged_connections = {}
    for group_key, group in contraction_dict.items():
        # Bookkeeping, not policy-driven: keep the original objects.
        provenance = {member._id: member for member in group}

        total = apply_policy(
            weight_policy, [getattr(member, "weight", 0) for member in group]
        )
        if total is None:
            total = 0  # Connection() needs a numeric weight

        pre = result.neurons[group_key[0].name]
        post = result.neurons[group_key[1].name]
        merged_conn = Connection(
            pre, post, connection_type=group_key[2], weight=total
        )
        merged_conn.set_property("contraction_data", copy.copy(provenance))

        for policy in attribute_policies:
            combined = apply_policy(
                policy, [getattr(member, policy.name, None) for member in group]
            )
            # Only record attributes that actually have content.
            if combined is None or combined == [] or combined == {}:
                continue
            merged_conn.set_property(policy.name, combined)

        merged_connections[(pre, post, merged_conn.uid)] = merged_conn

    result.connections = merged_connections
    result.update_network()
    result.update_connections()
    result.update_neurons()

    # Provenance: record the merge contract alongside the data it made.
    result.fold_policy = policy_set
    return result
def copy_data_from(self, nervous_system):
    """
    Placeholder for copying data from another nervous system into this one.

    NOTE: not implemented. The method currently does nothing: neither
    ``self`` nor ``nervous_system`` is read or modified. The previous
    docstring described a copy that never happened. This one documents
    the actual behavior so callers do not rely on a silent no-op.
    Implement the copy here before depending on it.

    Args:
        nervous_system (NervousSystem): The nervous system data would be
            copied from (currently ignored).

    Returns:
        None
    """
    return None
def neurons_have(self, key):
    """
    Look up one node attribute across the whole network.

    Args:
        key (str): Name of the node attribute to collect.

    Returns:
        dict: ``{neuron: value}`` as produced by
        ``networkx.get_node_attributes`` for this graph and ``key``.
    """
    attribute_by_neuron = nx.get_node_attributes(self, key)
    return attribute_by_neuron
def connections_have(self, key):
    """
    Look up one edge attribute across the whole network.

    Args:
        key (str): Name of the edge attribute to collect.

    Returns:
        dict: edge -> value, as produced by
        ``networkx.get_edge_attributes`` for this graph and ``key``.
        Because this class is a MultiDiGraph, the edge keys are presumably
        ``(pre, post, key)`` triples — confirm against the installed
        networkx version.
    """
    attribute_by_edge = nx.get_edge_attributes(self, key)
    return attribute_by_edge
def connections_between(self, neuron1, neuron2, directed=True):
    """
    Return the connections between two neurons.

    Args:
        neuron1 (Neuron): Neuron whose ``get_connections`` is queried.
        neuron2 (Neuron): The other neuron.
        directed (bool): If True (default), query with
            ``direction="out"``. If False, call ``get_connections`` without
            a direction argument and use its default.

    Returns:
        Whatever ``neuron1.get_connections`` returns.
    """
    if not directed:
        return neuron1.get_connections(neuron2)
    return neuron1.get_connections(neuron2, direction="out")
def __filter_node__(self, node):
    """Report whether ``node`` belongs to the current node filter.

    Parameters:
        node: Graph node (neuron) to test.

    Returns:
        bool: True when ``node`` is in ``self._filtered_nodes``.
    """
    allowed_nodes = self._filtered_nodes
    return node in allowed_nodes

def __filter_edge__(self, neuron_1, neuron_2, key):
    """Report whether an edge belongs to the current edge filter.

    Parameters:
        neuron_1: Source node of the edge.
        neuron_2: Target node of the edge.
        key: Multigraph key distinguishing parallel edges.

    Returns:
        bool: True when ``(neuron_1, neuron_2, key)`` is in
        ``self._filtered_edges``.
    """
    edge_id = (neuron_1, neuron_2, key)
    allowed_edges = self._filtered_edges
    return edge_id in allowed_edges
def return_network_where(
    self, neurons_have=None, connections_have=None, condition="AND"
):
    """
    Returns a subgraph view of the current network based on attribute filters.

    Parameters:
        neurons_have (dict): Neuron attribute names mapped to required
            values. Defaults to no neuron filtering.
        connections_have (dict): Connection attribute names mapped to
            required values. Defaults to no connection filtering.
        condition (str): How per-attribute matches are combined. With
            'AND', an item must satisfy every attribute. With 'OR', any
            single attribute suffices. Default is 'AND'.

    Returns:
        The result of ``self.subgraph_view``. Neuron filters are passed as a
        set of graph nodes. Connection filters are passed as a set of
        ``(pre, post, key)`` identifiers, which is the form
        ``subgraph_view``'s edge filter tests. A filter dict that is empty
        or ``None`` is passed as ``None``. A filter that matches nothing
        passes an empty set.

    Raises:
        ValueError: If ``condition`` is not 'AND' or 'OR'.
    """
    if condition not in ("AND", "OR"):
        raise ValueError("condition must be 'AND' or 'OR'")
    combine = all if condition == "AND" else any

    if neurons_have is None:
        neurons_have = {}
    if connections_have is None:
        connections_have = {}

    ## First filter the neurons
    filtered_node_list = None
    if neurons_have:
        # One set of matching graph nodes per requested attribute; sets
        # make each membership test O(1) instead of a list scan.
        node_matches = [
            {node for node, val in self.neurons_have(key).items() if val == value}
            for key, value in neurons_have.items()
        ]
        filtered_node_list = {
            node
            for _name, node in self.neurons.items()
            if combine(node in matched for matched in node_matches)
        }

    ## Then filter the connections
    filtered_edge_list = None
    if connections_have:
        edge_matches = [
            {edge for edge, val in self.connections_have(key).items() if val == value}
            for key, value in connections_have.items()
        ]
        # Collect the (pre, post, key) identifiers, not the Connection
        # objects. The edge filter of the view tests ``(u, v, k)`` tuples,
        # so Connection objects would never match.
        filtered_edge_list = {
            edge_key
            for edge_key, _conn in self.connections.items()
            if combine(edge_key in matched for matched in edge_matches)
        }

    return self.subgraph_view(
        filter_neurons=filtered_node_list, filter_connections=filtered_edge_list
    )
def copy(self, name=None, copy_type="deep"):
    """
    Returns a copy of the Nervous System object.

    Parameters:
        name (str, optional): Network name for the new object. Only used by
            the ``"deep_with_data"`` and ``"deep_without_data"`` modes.
            Defaults to ``self.name + "_copy"``.
        copy_type (str): Copy strategy:
            - ``"shallow"``: networkx graph ``copy(as_view=False)``.
            - ``"deep"`` (default): ``copy.deepcopy(self)``.
            - ``"deep_with_data"``: a new ``NervousSystem`` on the same worm,
              rebuilt with ``create_neurons_from`` /
              ``create_connections_from`` using ``data=True``, plus a deep
              copy of ``visualization_metadata``.
            - ``"deep_without_data"``: as above, with ``data=False``.

    Returns:
        object: The copied network.

    Raises:
        ValueError: If ``copy_type`` is not one of the modes above.
    """
    if copy_type == "shallow":
        return super().copy(as_view=False)
    if copy_type == "deep":
        return copy.deepcopy(self)
    if copy_type in ("deep_with_data", "deep_without_data"):
        # The two rebuild modes differ only in whether node/edge data is
        # carried over, so they share a single code path.
        with_data = copy_type == "deep_with_data"
        rebuilt = NervousSystem(self.worm, network=name or self.name + "_copy")
        rebuilt.create_neurons_from(self, data=with_data)
        rebuilt.create_connections_from(self, data=with_data)
        rebuilt.visualization_metadata = copy.deepcopy(self.visualization_metadata)
        return rebuilt
    raise ValueError(
        "copy_type must be 'deep', 'shallow', 'deep_with_data', "
        "or 'deep_without_data'"
    )
def copy_neurons(self, name=None, data=False):
    """Build a new network that contains only this network's neurons.

    Args:
        name (str, optional): Name of the new network. Defaults to
            ``self.name + "_copy"``.
        data (bool): Forwarded to ``create_neurons_from``. Default False.

    Returns:
        NervousSystem: A new network on the same worm, populated with
        neurons only (no connections are created here).
    """
    target_name = name or self.name + "_copy"
    neuron_only = NervousSystem(self.worm, network=target_name)
    neuron_only.create_neurons_from(self, data=data)
    return neuron_only
def subgraph_view(self, filter_neurons=None, filter_connections=None):
    """Creates a read-only, filtered view of this network.

    Parameters:
        filter_neurons (Iterable, optional): Graph nodes (neurons) to keep.
            ``None`` (default) applies no node filter. An empty iterable
            keeps no nodes.
        filter_connections (Iterable, optional): ``(pre, post, key)`` edge
            identifiers to keep. ``None`` (default) applies no edge filter.
            An empty iterable keeps no edges.

    Returns:
        A ``networkx`` subgraph view of this graph.

    Notes:
        The filter sets are captured per call. Creating another view later
        therefore does not change the contents of an existing one. The most
        recent filters are still mirrored onto ``self._filtered_nodes`` /
        ``self._filtered_edges`` for code that reads them.
    """
    view_filters = {}

    if filter_neurons is not None:
        allowed_nodes = set(filter_neurons)
        self._filtered_nodes = allowed_nodes

        def _keep_node(node, _allowed=allowed_nodes):
            return node in _allowed

        view_filters["filter_node"] = _keep_node

    if filter_connections is not None:
        allowed_edges = set(filter_connections)
        self._filtered_edges = allowed_edges

        # On a MultiDiGraph, networkx calls filter_edge with (u, v, key).
        def _keep_edge(neuron_1, neuron_2, key, _allowed=allowed_edges):
            return (neuron_1, neuron_2, key) in _allowed

        view_filters["filter_edge"] = _keep_edge

    return nx.subgraph_view(self, **view_filters)
def search_motifs(self, motif):
    """Find every placement of ``motif`` inside this network.

    Matching is done with networkx's ``DiGraphMatcher`` and its
    ``subgraph_isomorphisms_iter`` (this network is the larger graph).

    Args:
        motif: A networkx graph describing the pattern to look for.

    Returns:
        list[dict]: One dict per match. Each dict maps every edge of
        ``motif`` to the pair of network nodes that edge's endpoints
        matched.
    """
    matcher = nx.algorithms.isomorphism.DiGraphMatcher(self, motif)
    placements = []
    for mapping in matcher.subgraph_isomorphisms_iter():
        # The matcher maps network node -> motif node; flip it so motif
        # endpoints can be resolved back to network nodes.
        motif_to_network = dict((m, g) for g, m in mapping.items())
        placement = {}
        for motif_edge in motif.edges:
            placement[motif_edge] = (
                motif_to_network[motif_edge[0]],
                motif_to_network[motif_edge[1]],
            )
        placements.append(placement)
    return placements
def shortest_path(self, source, target, weight=None, method="dijkstra"):
    """Return one shortest path from ``source`` to ``target``.

    Thin wrapper over ``networkx.shortest_path``. ``weight`` and ``method``
    are forwarded unchanged.
    """
    path = nx.shortest_path(
        self, source=source, target=target, weight=weight, method=method
    )
    return path
def shortest_paths(self, source, target, weight=None, method="dijkstra"):
    """Return every shortest path from ``source`` to ``target``.

    Thin wrapper over ``networkx.all_shortest_paths``. The result is
    whatever that function returns (not materialized here).
    """
    paths = nx.all_shortest_paths(
        self, source=source, target=target, weight=weight, method=method
    )
    return paths
def export_graph(self, path, fmt="dot"):
    """
    Exports the graph to the specified path.

    Parameters:
        path (str): Destination file path.
        fmt (str): Output format. One of:
            ``"dot"`` (networkx pydot writer), ``"graphviz"`` (networkx
            AGraph writer), ``"nx"`` / ``"graphml"`` (GraphML), ``"json"``
            (Cytoscape JSON via ``networkx.cytoscape_data``), or ``"gml"``.
            Defaults to ``"dot"``.

    Returns:
        None

    Raises:
        ValueError: If ``fmt`` is not a supported format.
    """
    if fmt == "dot":
        nx.drawing.nx_pydot.write_dot(self, path)
    elif fmt == "graphviz":
        nx.drawing.nx_agraph.write_dot(self, path)
    elif fmt in ("nx", "graphml"):
        nx.write_graphml(self, path)
    elif fmt == "json":
        # cytoscape_data takes the graph plus optional attribute-name
        # arguments. The file path must not be passed to it; previously it
        # was passed positionally and misread as one of those arguments.
        payload = nx.cytoscape_data(self)
        # Serialize before opening so a serialization error does not leave
        # a truncated file behind.
        # NOTE(review): cytoscape_data embeds each node object itself in the
        # output; if nodes are not JSON-serializable this raises TypeError.
        text = json.dumps(payload, ensure_ascii=False, indent=4)
        with open(path, "w", encoding="utf-8") as f:
            f.write(text)
    elif fmt == "gml":
        nx.write_gml(self, path)
    else:
        raise ValueError(
            "format must be 'dot', 'graphviz', 'nx', 'json', 'gml', or 'graphml'"
        )
def to_dict(self) -> dict:
    """Convert the whole network into plain, JSON-friendly Python data.

    Each neuron contributes ``Neuron.to_dict()`` and each connection
    contributes ``Connection.to_dict()``. A ``"groups"`` list of group names
    is added to any node or link that belongs to a group. Group summaries,
    visualization hints and (when a worm is attached) the organism name
    are added as well.

    Returns:
        dict: Keys, in order:
            - ``name``: network name.
            - ``nodes``: list of neuron dicts.
            - ``links``: list of connection dicts.
            - ``groups``: list of summaries for NeuronGroup / ConnectionGroup
              entries in ``self.groups``.
            - ``visualization_metadata``: visualization hints ({} if absent).
            - ``organism``: ``worm.name`` (or None), only when a truthy
              ``worm`` is attached.
    """

    def _endpoint_label(endpoint):
        # Group connection ids may hold Neuron objects or plain names.
        return endpoint.name if hasattr(endpoint, "name") else endpoint

    # Reverse indexes: member -> names of the groups it belongs to.
    neuron_memberships = {}
    edge_memberships = {}
    for group_name, grp in self.groups.items():
        if hasattr(grp, "neurons"):  # NeuronGroup
            for member_name in grp.neurons:
                neuron_memberships.setdefault(member_name, []).append(group_name)
        if hasattr(grp, "connections"):  # ConnectionGroup
            for edge_id in grp.connections:
                label = (
                    _endpoint_label(edge_id[0]),
                    _endpoint_label(edge_id[1]),
                    edge_id[2],
                )
                edge_memberships.setdefault(label, []).append(group_name)

    node_dicts = []
    for neuron_name, neuron in self.neurons.items():
        entry = neuron.to_dict()
        if neuron_name in neuron_memberships:
            entry["groups"] = neuron_memberships[neuron_name]
        node_dicts.append(entry)

    link_dicts = []
    for (pre, post, edge_key), conn in self.connections.items():
        entry = conn.to_dict()
        label = (pre.name, post.name, edge_key)
        if label in edge_memberships:
            entry["groups"] = edge_memberships[label]
        link_dicts.append(entry)

    group_summaries = []
    for group_name, grp in self.groups.items():
        if isinstance(grp, NeuronGroup):
            summary = {
                "name": group_name,
                "type": "neuron",
                "count": len(grp),
                "members": list(grp.neurons),
            }
        elif isinstance(grp, ConnectionGroup):
            summary = {
                "name": group_name,
                "type": "connection",
                "count": len(grp),
                "members": [f"{c.pre.name}->{c.post.name}" for c in grp.members],
            }
        else:
            continue
        group_summaries.append(summary)

    result = {
        "name": self.name,
        "nodes": node_dicts,
        "links": link_dicts,
        "groups": group_summaries,
        "visualization_metadata": getattr(self, "visualization_metadata", {}),
    }
    if hasattr(self, "worm") and self.worm:
        result["organism"] = getattr(self.worm, "name", None)
    return result
def remove_unconnected_neurons(self):
    """Drop every neuron that has no connections, then resync neurons.

    Isolated nodes are found with ``networkx.isolates``, removed from
    the graph, and ``update_neurons`` is called afterwards.

    Returns:
        None
    """
    # Materialize before mutating: isolates() yields lazily from the graph.
    isolated = tuple(nx.isolates(self))
    self.remove_nodes_from(isolated)
    self.update_neurons()
def make_neuron_group(self, members, group_name=None):
    """Build a ``NeuronGroup`` over ``members`` in this network.

    Parameters:
        members (Iterable[Union[Neuron, str]]): Group members, either
            ``Neuron`` instances or neuron names (names are resolved
            against ``self.neurons``).
        group_name (str, optional): Name for the group. Auto-generated when
            omitted.

    Returns:
        NeuronGroup: The newly constructed group.
    """
    new_group = NeuronGroup(self, members, group_name)
    return new_group
def delete_neuron_group(self, groupname):
    """Remove the group registered under ``groupname`` from ``self.groups``.

    Parameters:
        groupname (str): Key of the group to remove.

    Returns:
        None

    Raises:
        KeyError: If no group is registered under ``groupname``.
    """
    self.groups.pop(groupname)
def make_connection_group(self, members, group_name=None):
    """Build a ``ConnectionGroup`` over ``members`` in this network.

    Parameters:
        members (Iterable[Connection]): Group members. Each must be a
            ``Connection`` instance, because connection identity is the
            ``(pre, post, uid)`` triple rather than a string name.
        group_name (str, optional): Name for the group. Auto-generated when
            omitted.

    Returns:
        ConnectionGroup: The newly constructed group.
    """
    new_group = ConnectionGroup(self, members, group_name)
    return new_group
def delete_connection_group(self, groupname):
    """Remove the group registered under ``groupname`` from ``self.groups``.

    Parameters:
        groupname (str): Key of the group to remove.

    Returns:
        None

    Raises:
        KeyError: If no group is registered under ``groupname``.
    """
    self.groups.pop(groupname)
def __delete__(self, neuron):
    """Remove ``neuron`` from the graph and resync the neuron registry.

    NOTE(review): ``__delete__`` is also the Python descriptor-protocol
    hook name. It is only invoked implicitly if a NervousSystem instance
    is stored as a class attribute of another class; here it is used as
    an ordinary method.
    """
    # Drop the node (networkx also drops its incident edges), then refresh
    # the neuron bookkeeping.
    remove = self.remove_node
    remove(neuron)
    self.update_neurons()