Source code for cedne.core.connection

"""
Connection primitives for CeDNe.

This module defines the core data structures for representing connections between cells
in the nervous system. It includes:

- `Connection`: Base class for all types of connections.
- `ChemicalSynapse`: Specialized connection for chemical synapses.
- `GapJunction`: Specialized connection for gap junctions.
- `BulkConnection`: Specialized connection for bulk connections.
- `ConnectionGroup`: Container for managing sets of connections.
- `Path`: A sequence of connections between cells.

Each connection type can maintain its own properties and weights while sharing common
functionality through the base Connection class.
"""

__author__ = "Sahil Moza"
__date__ = "2025-04-06"
__license__ = "MIT"

import networkx as nx
from collections import Counter
import numpy as np
from .io import generate_random_string
from typing import Optional, List, Dict, Any, Union, TYPE_CHECKING
from numbers import Number

if TYPE_CHECKING:
    from .neuron import Neuron
    from .network import NervousSystem

[docs] class Connection: ''' This class represents a connection between two cells. ''' def __init__(self, pre: 'Neuron', post: 'Neuron', uid=None, connection_type='chemical-synapse', **kwargs): """ Initializes a new instance of the Connection class. Args: pre (Neuron): The neuron sending the connection. post (Neuron): The neuron receiving the connection. uid (int, optional): The unique identifier for the connection. connection_type (str, optional): The type of the connection. weight (float, optional): The weight of the connection. Must be a numeric value. Raises: ValueError: If weight is not a numeric value. AssertionError: If pre and post neurons are from different networks. """ self.pre = pre self.post = post self.network = post.network # Validate weight weight = kwargs.pop('weight', 1) if not (isinstance(weight, Number) or isinstance(weight, np.number)): raise ValueError("Weight must be a numeric value, but is of type {}".format(type(weight))) self.weight = weight if pre.network != post.network: raise AssertionError("The Nervous Systems of the pre and post neurons must be the same.") if not uid: self.uid = self.network.add_edge(pre, post, weight=self.weight, connection_type=connection_type, **kwargs) else: self.uid = uid self.network.add_edge(pre, post, key=uid, weight=self.weight, connection_type=connection_type, **kwargs) self._id = (pre, post, self.uid) self.connection_type = connection_type self.pre.out_connections[self._id] = self self.post.in_connections[self._id] = self for key, value in kwargs.items(): self.set_property(key, value) @property def by_name(self): """ Returns the connecting neuron names (Pre,Post) """ return (self.pre.name, self.post.name)
[docs] def update_weight(self, weight, delta=False): ''' Sets the connection weight ''' if not delta: self.weight = weight else: self.weight += weight nx.set_edge_attributes(self.network, {self._id:{'weight':self.weight}})
[docs] def set_property(self, key, val): ''' Sets an attribute for the class''' setattr(self, key, val) nx.set_edge_attributes(self.network, {self._id:{key:val}})
[docs] def get_property(self, key): ''' Gets an attribute for the class''' return getattr(self, key)
[docs] class ChemicalSynapse(Connection): ''' This is a convenience class that represents connections of type chemical synapses.''' def __init__(self, pre, post, uid=0, connection_type='chemical-synapse', weight=1, **kwargs): super().__init__(pre, post, uid=uid, connection_type=connection_type, weight=weight, **kwargs) self.position= kwargs.pop('position', {'AP': 0, 'LR': 0, 'DV': 0})
[docs] class GapJunction(Connection): ''' This is a convenience class that represents connections of type gap junctions.''' def __init__(self, pre, post, uid=1, connection_type='gap-junction', weight=1, **kwargs): super().__init__(pre, post, uid=uid, connection_type=connection_type, weight=weight, **kwargs) self.position= kwargs.pop('position', {'AP': 0, 'LR': 0, 'DV': 0})
[docs] class BulkConnection(Connection): ''' This is a convenience class that represents connections of type neuropeptide-receptors.''' def __init__(self, pre, post, uid, connection_type, weight=1, **kwargs): super().__init__(pre, post, uid=uid, connection_type=connection_type, weight=weight, **kwargs)
[docs] class ConnectionGroup: ''' This is a group of connections in the network''' def __init__(self, network, members=None, group_name=None) -> None: """ Initializes a new instance of the ConnectionGroup class. Parameters: groupname (str): The name of the connection group. members (List[str]): The list of members in the connection group. group_id (int, optional): The ID of the neuron group. Defaults to 0. Returns: None """ if group_name is None: self.group_name = 'Group-'+ generate_random_string(8) else: self.group_name = group_name if members is None: members = [] else: assert all([isinstance(m, Connection)for m in members]),\ "Connection group members must be of type Connection" self.members = members self.connections = {m._id:m for m in members} self.network = network self.neurons = set([neuron for m in members for neuron in (m.pre, m.post)]) assert self.group_name not in self.network.groups, \ f"Group name {self.group_name} already exists in the network" self.network.groups.update({self.group_name: self}) def __iter__(self): """ Returns an iterator over the members of the group. """ return iter(self.connections)
[docs] def clear(self): """ Removes all connections from the group. """ self.connections = {} self.members = []
[docs] def items(self): """ Returns the itemized connection dictionary """ return self.connections.items()
[docs] def keys(self): """ Returns the IDs for the Connection Group""" return list(self.connections.keys())
[docs] def values(self): """ Returns the IDs for the Connection Group""" return list(self.connections.values())
def __len__(self): """ Returns the number of members in the group. """ return len(self.connections) def __contains__(self, member): """ Returns True if the connection with the specified name is in the group, False otherwise. """ if TYPE_CHECKING: from .neuron import Neuron if isinstance(member, Neuron): return member in self.neurons return member in self.connections def __getitem__(self, connection_id): """ Returns the connection with the specified name in the group. """ return self.connections[connection_id] def __setitem__(self, connection_id, connection): """ Sets the connection with the specified name in the group. """ assert isinstance(connection, Connection), "Connection must be of type Connection" self.connections[connection_id] = connection
[docs] def update(self, member_dict): """ Updates the list of members in the group. """ assert all([isinstance(connection, Connection) for ename, connection in \ member_dict.items()]), "Connection group members must be\ of type Connection" self.connections.update(member_dict) self.members = list(self.connections.values())
[docs] def pop(self, connection_id): """ Deletes the connection with the specified name from the group. """ self.connections.pop(connection_id)
[docs] def set_property(self, property_name, property_value): """ Sets a new property attribute for all connections in the group. """ for connection in self.members: connection.set_property(property_name, property_value)
[docs] def get_property(self, property_name): """ Gets the property attribute for all connections in the group. """ return [connection.get_property(property_name) for connection in self.members]
[docs] def update_weights(self, weight, delta=False): """ Updates weights for all connections in the group. Args: weight (float): The new weight value or weight delta. delta (bool, optional): If True, weight is added to current weights. \ If False, weight replaces current weights. Raises: ValueError: If weight is not a numeric value. """ if not (isinstance(weight, Number) or isinstance(weight, np.number)): raise ValueError("Weight must be a numeric value, but is of type {}".format(type(weight))) for connection in self.members: connection.update_weight(weight, delta)
[docs] def update_weights_by_function(self, weight_function): """ Updates weights for all connections using a custom function. Args: weight_function (callable): A function that takes a Connection object and returns a new weight value. Raises: ValueError: If weight_function is not callable or returns non-numeric values. """ if not callable(weight_function): raise ValueError("weight_function must be callable") for connection in self.members: new_weight = weight_function(connection) if not (isinstance(new_weight, Number) or isinstance(new_weight, np.number)): raise ValueError("weight_function must return numeric values") connection.update_weight(new_weight)
[docs] def filter_by_type(self, connection_type): """ Returns a new ConnectionGroup containing only connections of the specified type. Args: connection_type (str): The type of connections to filter for. Returns: ConnectionGroup: A new group containing only the filtered connections. """ filtered_members = [m for m in self.members if m.connection_type == connection_type] return ConnectionGroup(self.network, filtered_members, f"{self.group_name}-{connection_type}")
[docs] def filter_by_property(self, property_name, property_value): """ Returns a new ConnectionGroup containing only connections with the specified property value. Args: property_name (str): The name of the property to filter by. property_value: The value to match. Returns: ConnectionGroup: A new group containing only the filtered connections. """ filtered_members = [m for m in self.members if hasattr(m, property_name) and getattr(m, property_name) == property_value] return ConnectionGroup(self.network, filtered_members, f"{self.group_name}-{property_name}-{property_value}")
[docs] def filter_by_function(self, filter_function): """ Returns a new ConnectionGroup containing only connections that pass the filter function. Args: filter_function (callable): A function that takes a Connection object and returns True if the connection should be included. Returns: ConnectionGroup: A new group containing only the filtered connections. """ if not callable(filter_function): raise ValueError("filter_function must be callable") filtered_members = [m for m in self.members if filter_function(m)] return ConnectionGroup(self.network, filtered_members, f"{self.group_name}-filtered")
[docs] def get_statistics(self): """ Returns statistics about the connections in the group. Returns: dict: A dictionary containing statistics about the connections. """ stats = { 'count': len(self.members), 'weight_mean': sum(m.weight for m in self.members) / len(self.members) if self.members else 0, 'weight_min': min(m.weight for m in self.members) if self.members else 0, 'weight_max': max(m.weight for m in self.members) if self.members else 0, 'types': Counter([m.connection_type for m in self.members]) } return stats
[docs] class Path(ConnectionGroup): ''' This is a sequence of Connections in the network.''' def __init__(self, network, members=None, group_name=None): if group_name is None: group_name = 'Path-'+ generate_random_string(8) if members is None: members = [] self.source = None self.target = None else: assert all([isinstance(m, Connection)for m in members]),\ "Path members must be of type Connection" if len(members) > 1: # Check continuity only if there's more than one connection for i in range(len(members)-1): if members[i].post != members[i+1].pre: raise AssertionError("Path members must be continuous connections from source to target") if members: self.source = members[0].pre self.target = members[-1].post else: self.source = None self.target = None super().__init__(network, members, group_name)
[docs] def update(self, member_dict): """ Updates the list of members in the group. """ raise NotImplementedError(f'Cannot update connections in {self.__class__.__name__}')
[docs] def pop(self, connection_id): """ Deletes the connection with the specified name from the group. """ raise NotImplementedError(f'Cannot remove connections from {self.__class__.__name__}')
[docs] def get_length(self): """ Returns the number of connections in the path. Returns: int: The number of connections in the path. """ return len(self.members)
[docs] def get_total_weight(self): """ Returns the sum of weights of all connections in the path. Returns: float: The sum of connection weights. """ return sum(conn.weight for conn in self.members)
[docs] def get_average_weight(self): """ Returns the average weight of connections in the path. Returns: float: The average connection weight. """ if not self.members: return 0.0 return self.get_total_weight() / len(self.members)
[docs] def get_min_weight(self): """ Returns the minimum weight among all connections in the path. Returns: float: The minimum connection weight. """ if not self.members: return 0.0 return min(conn.weight for conn in self.members)
[docs] def get_max_weight(self): """ Returns the maximum weight among all connections in the path. Returns: float: The maximum connection weight. """ if not self.members: return 0.0 return max(conn.weight for conn in self.members)
[docs] def reverse(self): """ Creates a new path with connections in reverse order. Returns: Path: A new path with reversed connections. """ if not self.members: return Path(self.network) reversed_connections = [] for conn in reversed(self.members): # Create a new connection with reversed pre/post neurons reversed_conn = Connection(conn.post, conn.pre, connection_type=conn.connection_type, weight=conn.weight) reversed_connections.append(reversed_conn) return Path(self.network, reversed_connections, f"{self.group_name}-reversed")
[docs] def concatenate(self, other_path): """ Creates a new path by concatenating this path with another path. Args: other_path (Path): The path to concatenate with. Returns: Path: A new path containing all connections from both paths. Raises: AssertionError: If the paths cannot be concatenated (target of first path must match source of second path). """ if not self.members: return Path(self.network, other_path.members.copy()) if not other_path.members: return Path(self.network, self.members.copy()) if self.target != other_path.source: raise AssertionError("Cannot concatenate paths: target of first path must match source of second path") combined_connections = self.members.copy() + other_path.members.copy() return Path(self.network, combined_connections, f"{self.group_name}-{other_path.group_name}")
[docs] def is_valid(self): """ Checks if the path is valid (all connections are continuous). Returns: bool: True if the path is valid, False otherwise. """ if not self.members: return True try: for i in range(len(self.members)-1): if self.members[i].post != self.members[i+1].pre: return False return True except Exception: return False
[docs] def get_neurons(self): """ Returns a list of neurons in the path in order. Returns: list: List of neurons from source to target. """ if not self.members: return [] neurons = [self.source] for conn in self.members: neurons.append(conn.post) return neurons
[docs] def get_connection_types(self): """ Returns a list of connection types in the path in order. Returns: list: List of connection types. """ return [conn.connection_type for conn in self.members]