# Source code for cedne.core.connection

"""
Connection primitives for CeDNe.

This module defines the core data structures for representing connections between cells
in the nervous system. It includes:

- `Connection`: Base class for all types of connections.
- `ChemicalSynapse`: Specialized connection for chemical synapses.
- `GapJunction`: Specialized connection for gap junctions.
- `BulkConnection`: Specialized connection for bulk (non-synaptic, e.g. neuropeptide-receptor) connections.
- `ConnectionGroup`: Container for managing sets of connections.
- `Path`: A sequence of connections between cells.

Each connection type can maintain its own properties and weights while sharing common
functionality through the base Connection class.
"""

__author__ = "Sahil Moza"
__date__ = "2025-04-06"
__license__ = "MIT"

import networkx as nx
from collections import Counter
import numpy as np
from .io import generate_random_string
from .source import Citable, serialize_citations
from typing import TYPE_CHECKING
from numbers import Number

if TYPE_CHECKING:
    from .neuron import Neuron


class Connection(Citable):
    """A directed connection between two cells of the same nervous system.

    Constructing a ``Connection`` has side effects: it adds an edge to the
    owning network (``post.network``) and registers the new object in
    ``pre.out_connections`` and ``post.in_connections`` under its identity
    tuple ``(pre, post, uid)``.
    """

    def __init__(
        self,
        pre: "Neuron",
        post: "Neuron",
        uid=None,
        connection_type="chemical-synapse",
        **kwargs,
    ):
        """
        Initializes a new instance of the Connection class.

        Args:
            pre (Neuron): The neuron sending the connection.
            post (Neuron): The neuron receiving the connection.
            uid (int, optional): The unique identifier (edge key) for the
                connection. A falsy value (``None`` or ``0``) means the key is
                taken from the return value of ``network.add_edge``.
            connection_type (str, optional): The type of the connection.
            weight (float, optional): The weight of the connection (default 1).
                Must be a numeric value.
            **kwargs: Any other keyword is passed to ``network.add_edge`` and
                also stored as an attribute via :meth:`set_property`.

        Raises:
            ValueError: If weight is not a numeric value.
            AssertionError: If pre and post neurons are from different networks.
        """
        Citable.__init__(self)  # provides self.citations = {}
        self.pre = pre
        self.post = post
        self.network = post.network

        # Validate weight before anything touches the graph.
        weight = kwargs.pop("weight", 1)
        self._check_weight(weight)
        self.weight = weight

        if pre.network != post.network:
            raise AssertionError(
                "The Nervous Systems of the pre and post neurons must be the same."
            )

        # NOTE(review): ``not uid`` also treats uid=0 as "no uid", so the
        # ChemicalSynapse default (uid=0) does not pin edge key 0; the key is
        # whatever add_edge returns (presumably the next free multigraph key).
        # Switching to ``uid is None`` would change how repeated connections
        # between the same pair are keyed -- confirm intent before changing.
        if not uid:
            self.uid = self.network.add_edge(
                pre, post, weight=self.weight, connection_type=connection_type, **kwargs
            )
        else:
            self.uid = uid
            self.network.add_edge(
                pre,
                post,
                key=uid,
                weight=self.weight,
                connection_type=connection_type,
                **kwargs,
            )

        self._id = (pre, post, self.uid)
        self.connection_type = connection_type
        self.pre.out_connections[self._id] = self
        self.post.in_connections[self._id] = self

        # Mirror every extra keyword onto both the Python object and the edge.
        for key, value in kwargs.items():
            self.set_property(key, value)

    @staticmethod
    def _check_weight(weight):
        """Raise ``ValueError`` unless *weight* is a numeric scalar."""
        if not isinstance(weight, (Number, np.number)):
            raise ValueError(
                "Weight must be a numeric value, but is of type {}".format(type(weight))
            )

    def _parent_citables(self):
        """Walk citations up through containing ConnectionGroups and the NervousSystem."""
        parents = []
        if self.network is not None:
            for group in self.network.groups.values():
                # O(1) membership via the _id->Connection dict each ConnectionGroup maintains
                if (
                    isinstance(group, ConnectionGroup)
                    and group.connections.get(self._id) is self
                ):
                    parents.append(group)
            parents.append(self.network)
        return parents

    @property
    def by_name(self):
        """
        Returns the connecting neuron names (Pre,Post)
        """
        return (self.pre.name, self.post.name)

    def update_weight(self, weight, delta=False):
        """Set the connection weight and mirror it onto the graph edge.

        Args:
            weight (float): New weight, or the increment when ``delta`` is True.
            delta (bool, optional): If True, add ``weight`` to the current
                weight instead of replacing it.

        Raises:
            ValueError: If weight is not a numeric value (same rule as
                ``__init__``).
        """
        self._check_weight(weight)
        if delta:
            self.weight += weight
        else:
            self.weight = weight
        nx.set_edge_attributes(self.network, {self._id: {"weight": self.weight}})

    def set_property(self, key, val):
        """Set an attribute on the connection and on its graph edge."""
        setattr(self, key, val)
        nx.set_edge_attributes(self.network, {self._id: {key: val}})

    def get_property(self, key):
        """Return the attribute ``key`` of this connection."""
        return getattr(self, key)

    def to_dict(self) -> dict:
        """Serialize connection to a plain Python dictionary.

        Returns a dict with guaranteed keys: ``source``, ``target``,
        ``weight``, ``type``. Optionally includes ``ligands``,
        ``neurotransmitters``, ``receptors`` and ``citations`` if present
        and non-empty.

        Returns:
            dict: A JSON-compatible dictionary representation.
        """
        d = {
            "source": self.pre.name,
            "target": self.post.name,
            "weight": self.weight,
            "type": self.connection_type,
        }
        # Semantic edge metadata
        for attr in ("ligands", "neurotransmitters"):
            val = getattr(self, attr, None)
            if val:
                d[attr] = val
        receptors = getattr(self, "receptors", None)
        if receptors:
            # Only receptor names are exported when receptors is a mapping.
            d["receptors"] = (
                list(receptors.keys()) if isinstance(receptors, dict) else receptors
            )
        # Citations attached directly to this connection (not the inherited chain)
        if getattr(self, "citations", None):
            d["citations"] = serialize_citations(self.citations)
        return d
class ChemicalSynapse(Connection):
    """Convenience :class:`Connection` subclass for chemical synapses.

    ``connection_type`` defaults to ``"chemical-synapse"`` and ``uid`` to ``0``.
    ``Connection.__init__`` treats a falsy ``uid`` as "not given", so with the
    default the edge key is whatever ``network.add_edge`` returns.

    The optional ``position`` keyword ends up on ``self.position``. When it is
    supplied, ``Connection`` has already stored it (on the object and on the
    edge). When omitted, the default ``{"AP": 0, "LR": 0, "DV": 0}`` is set
    on the instance only and is not written to the graph edge.
    """

    def __init__(
        self, pre, post, uid=0, connection_type="chemical-synapse", weight=1, **kwargs
    ):
        super().__init__(
            pre,
            post,
            uid=uid,
            connection_type=connection_type,
            weight=weight,
            **kwargs,
        )
        if "position" in kwargs:
            self.position = kwargs["position"]
        else:
            self.position = {"AP": 0, "LR": 0, "DV": 0}
class GapJunction(Connection):
    """Convenience :class:`Connection` subclass for gap junctions.

    ``connection_type`` defaults to ``"gap-junction"`` and ``uid`` to ``1``.
    Because ``1`` is truthy, it is passed to the network as the explicit edge
    key.

    The optional ``position`` keyword ends up on ``self.position``. When
    omitted, the default ``{"AP": 0, "LR": 0, "DV": 0}`` is set on the
    instance only and is not written to the graph edge.
    """

    def __init__(
        self, pre, post, uid=1, connection_type="gap-junction", weight=1, **kwargs
    ):
        super().__init__(
            pre,
            post,
            uid=uid,
            connection_type=connection_type,
            weight=weight,
            **kwargs,
        )
        default_position = {"AP": 0, "LR": 0, "DV": 0}
        self.position = kwargs.get("position", default_position)
class BulkConnection(Connection):
    """Convenience :class:`Connection` subclass for bulk (neuropeptide-receptor) connections.

    Unlike the synapse subclasses, ``uid`` and ``connection_type`` have no
    defaults and must be supplied by the caller.
    """

    def __init__(self, pre, post, uid, connection_type, weight=1, **kwargs):
        # uid and connection_type line up with Connection's positional parameters.
        super().__init__(pre, post, uid, connection_type, weight=weight, **kwargs)
class ConnectionGroup(Citable):
    """A named group of connections registered in a nervous system.

    The group behaves like a mapping from connection id (the
    ``(pre, post, uid)`` tuple a Connection stores as ``_id``) to
    :class:`Connection`. Three views are kept in step by every mutator:

    * ``connections`` -- dict of id -> Connection (the mapping itself);
    * ``members`` -- list of the same Connection objects;
    * ``neurons`` -- set of every ``pre``/``post`` neuron of the members.

    Constructing a group registers it in ``network.groups`` under
    ``group_name``; the name must not already be taken in that network.
    """

    def __init__(self, network, members=None, group_name=None) -> None:
        """
        Initializes a new instance of the ConnectionGroup class.

        Parameters:
            network (NervousSystem): The owning nervous system.
            members (Iterable[Connection], optional): The connections in the
                group. Each entry must be a ``Connection`` instance --
                connection identity is the ``(pre, post, uid)`` triple, not a
                string. The iterable is copied. Defaults to empty.
            group_name (str, optional): The name of the connection group.
                Auto-generated if omitted.

        Raises:
            AssertionError: If a member is not a Connection, or the group name
                already exists in the network.
        """
        Citable.__init__(self)  # provides self.citations = {}
        if group_name is None:
            self.group_name = "Group-" + generate_random_string(8)
        else:
            self.group_name = group_name

        # Copy once: decouples the group from the caller's list and lets
        # one-shot iterables survive the validation pass below.
        members = [] if members is None else list(members)
        if not all(isinstance(m, Connection) for m in members):
            raise AssertionError("Connection group members must be of type Connection")

        self.members = members
        self.connections = {m._id: m for m in members}
        self.network = network
        self.neurons = {neuron for m in members for neuron in (m.pre, m.post)}
        if self.group_name in self.network.groups:
            raise AssertionError(
                f"Group name {self.group_name} already exists in the network"
            )
        self.network.groups.update({self.group_name: self})

    def _resync(self):
        """Rebuild ``members`` and ``neurons`` from ``connections`` after a mutation."""
        self.members = list(self.connections.values())
        self.neurons = {n for c in self.members for n in (c.pre, c.post)}

    def _parent_citables(self):
        """Walk citations up to the containing NervousSystem."""
        return (self.network,) if self.network is not None else ()

    def __iter__(self):
        """
        Returns an iterator over the connection ids (keys) of the group.
        """
        return iter(self.connections)

    def clear(self):
        """
        Removes all connections (and their neurons) from the group.
        """
        self.connections = {}
        self.members = []
        self.neurons = set()

    def items(self):
        """
        Returns the itemized connection dictionary (id -> Connection).
        """
        return self.connections.items()

    def keys(self):
        """Returns the IDs for the Connection Group as a list."""
        return list(self.connections.keys())

    def values(self):
        """Returns the Connection objects of the group as a list."""
        return list(self.connections.values())

    def __len__(self):
        """
        Returns the number of connections in the group.
        """
        return len(self.connections)

    def __contains__(self, member):
        """
        Return True if *member* is a Neuron touched by the group, or a
        connection id present in the group; False otherwise.
        """
        # Imported at call time: the module-level import is guarded by
        # TYPE_CHECKING (presumably to avoid a circular import), so the name
        # ``Neuron`` is not bound at runtime.
        from .neuron import Neuron

        if isinstance(member, Neuron):
            return member in self.neurons
        return member in self.connections

    def __getitem__(self, connection_id):
        """
        Returns the connection with the specified id in the group.
        """
        return self.connections[connection_id]

    def __setitem__(self, connection_id, connection):
        """
        Sets the connection with the specified id in the group.

        Raises:
            AssertionError: If ``connection`` is not a Connection.
        """
        if not isinstance(connection, Connection):
            raise AssertionError("Connection must be of type Connection")
        replaced = connection_id in self.connections
        self.connections[connection_id] = connection
        if replaced:
            # The old connection's neurons may no longer be referenced.
            self._resync()
        else:
            self.members.append(connection)
            self.neurons.update((connection.pre, connection.post))

    def update(self, member_dict):
        """
        Adds/replaces connections from a mapping of id -> Connection.

        Raises:
            AssertionError: If any value is not a Connection.
        """
        if not all(isinstance(c, Connection) for c in member_dict.values()):
            raise AssertionError("Connection group members must be of type Connection")
        self.connections.update(member_dict)
        self._resync()

    def pop(self, connection_id):
        """
        Removes the connection with the specified id from the group and
        returns it.

        Raises:
            KeyError: If the id is not in the group.
        """
        connection = self.connections.pop(connection_id)
        self._resync()
        return connection

    def set_property(self, property_name, property_value):
        """
        Sets a new property attribute for all connections in the group.
        """
        for connection in self.members:
            connection.set_property(property_name, property_value)

    def get_property(self, property_name):
        """
        Gets the property attribute for all connections in the group.
        """
        return [connection.get_property(property_name) for connection in self.members]

    def update_weights(self, weight, delta=False):
        """
        Updates weights for all connections in the group.

        Args:
            weight (float): The new weight value or weight delta.
            delta (bool, optional): If True, weight is added to current weights.
                If False, weight replaces current weights.

        Raises:
            ValueError: If weight is not a numeric value.
        """
        if not isinstance(weight, (Number, np.number)):
            raise ValueError(
                "Weight must be a numeric value, but is of type {}".format(type(weight))
            )
        for connection in self.members:
            connection.update_weight(weight, delta)

    def update_weights_by_function(self, weight_function):
        """
        Updates weights for all connections using a custom function.

        Args:
            weight_function (callable): A function that takes a Connection
                object and returns a new weight value.

        Raises:
            ValueError: If weight_function is not callable or returns
                non-numeric values.
        """
        if not callable(weight_function):
            raise ValueError("weight_function must be callable")
        for connection in self.members:
            new_weight = weight_function(connection)
            if not isinstance(new_weight, (Number, np.number)):
                raise ValueError("weight_function must return numeric values")
            connection.update_weight(new_weight)

    def filter_by_type(self, connection_type):
        """
        Returns a new ConnectionGroup containing only connections of the
        specified type, named ``"<group_name>-<connection_type>"`` and
        registered in the network.

        Args:
            connection_type (str): The type of connections to filter for.

        Returns:
            ConnectionGroup: A new group containing only the filtered connections.
        """
        filtered_members = [
            m for m in self.members if m.connection_type == connection_type
        ]
        return ConnectionGroup(
            self.network, filtered_members, f"{self.group_name}-{connection_type}"
        )

    def filter_by_property(self, property_name, property_value):
        """
        Returns a new ConnectionGroup containing only connections with the
        specified property value. Connections lacking the property are
        excluded.

        Args:
            property_name (str): The name of the property to filter by.
            property_value: The value to match.

        Returns:
            ConnectionGroup: A new group containing only the filtered connections.
        """
        filtered_members = [
            m
            for m in self.members
            if hasattr(m, property_name)
            and getattr(m, property_name) == property_value
        ]
        return ConnectionGroup(
            self.network,
            filtered_members,
            f"{self.group_name}-{property_name}-{property_value}",
        )

    def filter_by_function(self, filter_function):
        """
        Returns a new ConnectionGroup containing only connections that pass
        the filter function.

        Args:
            filter_function (callable): A function that takes a Connection
                object and returns True if the connection should be included.

        Returns:
            ConnectionGroup: A new group containing only the filtered connections.

        Raises:
            ValueError: If filter_function is not callable.
        """
        if not callable(filter_function):
            raise ValueError("filter_function must be callable")
        filtered_members = [m for m in self.members if filter_function(m)]
        return ConnectionGroup(
            self.network, filtered_members, f"{self.group_name}-filtered"
        )

    def get_statistics(self):
        """
        Returns statistics about the connections in the group.

        Returns:
            dict: ``count``, ``weight_mean``, ``weight_min``, ``weight_max``
            (the weight entries are 0 for an empty group) and ``types``
            (a Counter of connection types).
        """
        weights = [m.weight for m in self.members]
        return {
            "count": len(self.members),
            "weight_mean": sum(weights) / len(weights) if weights else 0,
            "weight_min": min(weights) if weights else 0,
            "weight_max": max(weights) if weights else 0,
            "types": Counter(m.connection_type for m in self.members),
        }

    # ---- Set operations ----

    def _check_compatible(self, other):
        """Raise AssertionError unless *other* is a ConnectionGroup on the same network."""
        if not isinstance(other, ConnectionGroup):
            raise AssertionError("Operand must be a ConnectionGroup")
        if self.network is not other.network:
            raise AssertionError("Both groups must belong to the same network")

    def union(
        self, other: "ConnectionGroup", group_name: str = None
    ) -> "ConnectionGroup":
        """Return a new ConnectionGroup containing connections from both groups.

        Args:
            other: Another ConnectionGroup to union with.
            group_name: Optional name for the resulting group.

        Returns:
            A new ConnectionGroup with the combined members (self's order
            first; on an id clash the connection from ``other`` wins).
        """
        self._check_compatible(other)
        merged = {**self.connections, **other.connections}
        name = group_name or f"{self.group_name}_union_{other.group_name}"
        return ConnectionGroup(
            self.network, members=list(merged.values()), group_name=name
        )

    def intersection(
        self, other: "ConnectionGroup", group_name: str = None
    ) -> "ConnectionGroup":
        """Return a new ConnectionGroup containing only connections present in both groups.

        Args:
            other: Another ConnectionGroup to intersect with.
            group_name: Optional name for the resulting group.

        Returns:
            A new ConnectionGroup with the shared members, in self's order.
        """
        self._check_compatible(other)
        # Iterate self's dict (not a set) so member order is deterministic.
        members = [c for cid, c in self.connections.items() if cid in other.connections]
        name = group_name or f"{self.group_name}_intersect_{other.group_name}"
        return ConnectionGroup(self.network, members=members, group_name=name)

    def difference(
        self, other: "ConnectionGroup", group_name: str = None
    ) -> "ConnectionGroup":
        """Return a new ConnectionGroup containing connections in self but not in other.

        Args:
            other: Another ConnectionGroup to subtract.
            group_name: Optional name for the resulting group.

        Returns:
            A new ConnectionGroup with the difference members, in self's order.
        """
        self._check_compatible(other)
        members = [
            c for cid, c in self.connections.items() if cid not in other.connections
        ]
        name = group_name or f"{self.group_name}_diff_{other.group_name}"
        return ConnectionGroup(self.network, members=members, group_name=name)

    def __or__(self, other: "ConnectionGroup") -> "ConnectionGroup":
        """Operator ``|`` — union."""
        return self.union(other)

    def __and__(self, other: "ConnectionGroup") -> "ConnectionGroup":
        """Operator ``&`` — intersection."""
        return self.intersection(other)

    def __sub__(self, other: "ConnectionGroup") -> "ConnectionGroup":
        """Operator ``-`` — difference."""
        return self.difference(other)
class Path(ConnectionGroup):
    """An ordered, continuous sequence of Connections in the network.

    Each connection's ``post`` neuron must equal the next connection's ``pre``
    neuron. ``source`` is the first connection's ``pre`` and ``target`` the
    last connection's ``post`` (both ``None`` for an empty path).
    ``update`` and ``pop`` raise ``NotImplementedError``. The inherited
    ``__setitem__`` is not guarded and can break continuity.
    """

    def __init__(self, network, members=None, group_name=None):
        """
        Args:
            network (NervousSystem): The owning nervous system.
            members (Iterable[Connection], optional): The connections in path
                order. The iterable is copied. Defaults to empty.
            group_name (str, optional): Name for the path; defaults to
                ``"Path-<random>"``.

        Raises:
            AssertionError: If a member is not a Connection or the members are
                not continuous.
        """
        if group_name is None:
            group_name = "Path-" + generate_random_string(8)
        members = [] if members is None else list(members)
        if not all(isinstance(m, Connection) for m in members):
            raise AssertionError("Path members must be of type Connection")
        if not self._is_continuous(members):
            raise AssertionError(
                "Path members must be continuous connections from source to target"
            )
        self.source = members[0].pre if members else None
        self.target = members[-1].post if members else None
        super().__init__(network, members, group_name)

    @staticmethod
    def _is_continuous(connections):
        """Return True if every connection's ``post`` equals the next one's ``pre``."""
        return not any(
            current.post != following.pre
            for current, following in zip(connections, connections[1:])
        )

    def update(self, member_dict):
        """
        Not supported: paths cannot be modified after construction.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            f"Cannot update connections in {self.__class__.__name__}"
        )

    def pop(self, connection_id):
        """
        Not supported: paths cannot be modified after construction.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            f"Cannot remove connections from {self.__class__.__name__}"
        )

    def clear(self):
        """
        Removes all connections and resets ``source``/``target`` to None so
        they do not go stale.
        """
        super().clear()
        self.source = None
        self.target = None

    def get_length(self):
        """
        Returns the number of connections in the path.

        Returns:
            int: The number of connections in the path.
        """
        return len(self.members)

    def get_total_weight(self):
        """
        Returns the sum of weights of all connections in the path.

        Returns:
            float: The sum of connection weights (0 for an empty path).
        """
        return sum(conn.weight for conn in self.members)

    def get_average_weight(self):
        """
        Returns the average weight of connections in the path.

        Returns:
            float: The average connection weight (0.0 for an empty path).
        """
        if not self.members:
            return 0.0
        return self.get_total_weight() / len(self.members)

    def get_min_weight(self):
        """
        Returns the minimum weight among all connections in the path.

        Returns:
            float: The minimum connection weight (0.0 for an empty path).
        """
        if not self.members:
            return 0.0
        return min(conn.weight for conn in self.members)

    def get_max_weight(self):
        """
        Returns the maximum weight among all connections in the path.

        Returns:
            float: The maximum connection weight (0.0 for an empty path).
        """
        if not self.members:
            return 0.0
        return max(conn.weight for conn in self.members)

    def reverse(self):
        """
        Creates a new path with connections in reverse order.

        Note: every reversed hop is a brand-new ``Connection``. Its
        constructor adds a new edge to the network and registers the
        connection with its neurons. Only ``connection_type`` and ``weight``
        are carried over from the original connections.

        Returns:
            Path: A new path named ``"<group_name>-reversed"`` (or an
            unnamed empty path if this path is empty).
        """
        if not self.members:
            return Path(self.network)
        reversed_connections = [
            Connection(
                conn.post,
                conn.pre,
                connection_type=conn.connection_type,
                weight=conn.weight,
            )
            for conn in reversed(self.members)
        ]
        return Path(self.network, reversed_connections, f"{self.group_name}-reversed")

    def concatenate(self, other_path):
        """
        Creates a new path by concatenating this path with another path.

        Args:
            other_path (Path): The path to concatenate with.

        Returns:
            Path: A new path containing all connections from both paths.

        Raises:
            AssertionError: If the paths cannot be concatenated (target of
                first path must match source of second path).
        """
        # Path.__init__ copies its members, so passing the lists directly is safe.
        if not self.members:
            return Path(self.network, other_path.members)
        if not other_path.members:
            return Path(self.network, self.members)
        if self.target != other_path.source:
            raise AssertionError(
                "Cannot concatenate paths: target of first path must match source of second path"
            )
        return Path(
            self.network,
            self.members + other_path.members,
            f"{self.group_name}-{other_path.group_name}",
        )

    def is_valid(self):
        """
        Checks if the path is valid (all connections are continuous).

        Returns:
            bool: True if the path is empty or continuous; False otherwise,
            including when a member lacks ``pre``/``post`` (e.g. after
            ``members`` was modified directly).
        """
        try:
            return self._is_continuous(self.members)
        except AttributeError:
            return False

    def get_neurons(self):
        """
        Returns a list of neurons in the path in order.

        Returns:
            list: Neurons from source to target (empty for an empty path).
        """
        if not self.members:
            return []
        return [self.members[0].pre] + [conn.post for conn in self.members]

    def get_connection_types(self):
        """
        Returns a list of connection types in the path in order.

        Returns:
            list: List of connection types.
        """
        return [conn.connection_type for conn in self.members]