"""
Connection primitives for CeDNe.
This module defines the core data structures for representing connections between cells
in the nervous system. It includes:
- `Connection`: Base class for all types of connections.
- `ChemicalSynapse`: Specialized connection for chemical synapses.
- `GapJunction`: Specialized connection for gap junctions.
- `BulkConnection`: Specialized connection for bulk connections.
- `ConnectionGroup`: Container for managing sets of connections.
- `Path`: A sequence of connections between cells.
Each connection type can maintain its own properties and weights while sharing common
functionality through the base Connection class.
"""
__author__ = "Sahil Moza"
__date__ = "2025-04-06"
__license__ = "MIT"
import networkx as nx
from collections import Counter
import numpy as np
from .io import generate_random_string
from .source import Citable, serialize_citations
from typing import TYPE_CHECKING
from numbers import Number
if TYPE_CHECKING:
from .neuron import Neuron
[docs]
class Connection(Citable):
    """A connection from one cell (``pre``) to another (``post``).

    Creating a ``Connection`` registers an edge on the network that owns
    both neurons and records the connection in ``pre.out_connections`` and
    ``post.in_connections`` under the key ``(pre, post, uid)``.
    """

    def __init__(
        self,
        pre: "Neuron",
        post: "Neuron",
        uid=None,
        connection_type="chemical-synapse",
        **kwargs,
    ):
        """
        Initializes a new instance of the Connection class.

        Args:
            pre (Neuron):
                The neuron sending the connection.
            post (Neuron):
                The neuron receiving the connection.
            uid (optional):
                The edge key for the connection. When ``None`` the key
                returned by ``network.add_edge`` is used. Any other value,
                including ``0``, is passed to ``add_edge`` as ``key``.
            connection_type (str, optional):
                The type of the connection.
            weight (float, optional):
                The weight of the connection. Must be a numeric value.
                Defaults to 1.
            **kwargs:
                Extra properties. They are forwarded to ``add_edge`` and
                also set as attributes on this object.

        Raises:
            ValueError: If weight is not a numeric value.
            AssertionError: If pre and post neurons are from different networks.
        """
        Citable.__init__(self)  # provides self.citations = {}
        self.pre = pre
        self.post = post
        self.network = post.network

        # Validate weight
        weight = kwargs.pop("weight", 1)
        if not isinstance(weight, (Number, np.number)):
            raise ValueError(
                "Weight must be a numeric value, but is of type {}".format(type(weight))
            )
        self.weight = weight

        if pre.network != post.network:
            raise AssertionError(
                "The Nervous Systems of the pre and post neurons must be the same."
            )

        # Compare against None rather than truthiness: 0 is a legitimate key
        # (it is the default uid of ChemicalSynapse) and must not be replaced
        # by an auto-generated one.
        if uid is None:
            self.uid = self.network.add_edge(
                pre, post, weight=self.weight, connection_type=connection_type, **kwargs
            )
        else:
            self.uid = uid
            self.network.add_edge(
                pre,
                post,
                key=uid,
                weight=self.weight,
                connection_type=connection_type,
                **kwargs,
            )

        self._id = (pre, post, self.uid)
        self.connection_type = connection_type
        self.pre.out_connections[self._id] = self
        self.post.in_connections[self._id] = self

        # Mirror the remaining keyword arguments as object attributes.
        for key, value in kwargs.items():
            self.set_property(key, value)

    def _parent_citables(self):
        """Walk citations up through containing ConnectionGroups and the NervousSystem."""
        parents = []
        if self.network is not None:
            for group in self.network.groups.values():
                # O(1) membership via the _id->Connection dict each ConnectionGroup maintains
                if (
                    isinstance(group, ConnectionGroup)
                    and group.connections.get(self._id) is self
                ):
                    parents.append(group)
            parents.append(self.network)
        return parents

    @property
    def by_name(self):
        """
        Returns the connecting neuron names (Pre,Post)
        """
        return (self.pre.name, self.post.name)

    def update_weight(self, weight, delta=False):
        """Sets the connection weight and mirrors it onto the network edge.

        Args:
            weight: The new weight, or the increment when ``delta`` is True.
            delta (bool, optional): If True, ``weight`` is added to the
                current weight instead of replacing it.
        """
        if not delta:
            self.weight = weight
        else:
            self.weight += weight
        nx.set_edge_attributes(self.network, {self._id: {"weight": self.weight}})

    def set_property(self, key, val):
        """Sets an attribute on this object and on the matching network edge."""
        setattr(self, key, val)
        nx.set_edge_attributes(self.network, {self._id: {key: val}})

    def get_property(self, key):
        """Gets an attribute of this object (raises AttributeError if unset)."""
        return getattr(self, key)

    def to_dict(self) -> dict:
        """Serialize connection to a plain Python dictionary.

        Returns a dict with guaranteed keys: ``source``, ``target``,
        ``weight``, ``type``. Optionally includes ``ligands``,
        ``neurotransmitters``, ``receptors`` and ``citations`` if present
        and non-empty.

        Returns:
            dict: A dictionary representation of the connection.
        """
        d = {
            "source": self.pre.name,
            "target": self.post.name,
            "weight": self.weight,
            "type": self.connection_type,
        }
        # Semantic edge metadata
        for attr in ("ligands", "neurotransmitters"):
            val = getattr(self, attr, None)
            if val:
                d[attr] = val
        if hasattr(self, "receptors") and self.receptors:
            # A receptor dict is reduced to its keys; anything else is kept as is.
            d["receptors"] = (
                list(self.receptors.keys())
                if isinstance(self.receptors, dict)
                else self.receptors
            )
        # Citations attached directly to this connection (not the inherited chain)
        if hasattr(self, "citations") and self.citations:
            d["citations"] = serialize_citations(self.citations)
        return d
[docs]
class ChemicalSynapse(Connection):
    """Convenience ``Connection`` subclass whose defaults describe a chemical synapse."""

    def __init__(
        self, pre, post, uid=0, connection_type="chemical-synapse", weight=1, **kwargs
    ):
        """
        Build the connection through ``Connection.__init__`` and attach a position.

        Args:
            pre: The neuron sending the connection.
            post: The neuron receiving the connection.
            uid: Identifier forwarded to the base class. Defaults to 0.
            connection_type (str): Defaults to ``"chemical-synapse"``.
            weight: Connection weight. Defaults to 1.
            **kwargs: Extra properties forwarded to the base class. A
                ``position`` entry, if given, is stored as ``self.position``;
                otherwise a zeroed ``AP``/``LR``/``DV`` dict is used.
        """
        base_options = dict(uid=uid, connection_type=connection_type, weight=weight)
        super().__init__(pre, post, **base_options, **kwargs)
        origin = {"AP": 0, "LR": 0, "DV": 0}
        self.position = kwargs.get("position", origin)
[docs]
class GapJunction(Connection):
    """Convenience ``Connection`` subclass whose defaults describe a gap junction."""

    def __init__(
        self, pre, post, uid=1, connection_type="gap-junction", weight=1, **kwargs
    ):
        """
        Build the connection through ``Connection.__init__`` and attach a position.

        Args:
            pre: The neuron on the ``pre`` side of the connection.
            post: The neuron on the ``post`` side of the connection.
            uid: Identifier forwarded to the base class. Defaults to 1.
            connection_type (str): Defaults to ``"gap-junction"``.
            weight: Connection weight. Defaults to 1.
            **kwargs: Extra properties forwarded to the base class. A
                ``position`` entry, if given, is stored as ``self.position``;
                otherwise a zeroed ``AP``/``LR``/``DV`` dict is used.
        """
        super().__init__(pre, post, uid, connection_type, weight=weight, **kwargs)
        if "position" in kwargs:
            self.position = kwargs["position"]
        else:
            self.position = {"AP": 0, "LR": 0, "DV": 0}
[docs]
class BulkConnection(Connection):
    """Convenience ``Connection`` subclass for connections of type neuropeptide-receptors.

    Unlike the other convenience subclasses, ``uid`` and ``connection_type``
    have no defaults here and must be supplied by the caller.
    """

    def __init__(self, pre, post, uid, connection_type, weight=1, **kwargs):
        """
        Forward every argument unchanged to ``Connection.__init__``.

        Args:
            pre: The neuron sending the connection.
            post: The neuron receiving the connection.
            uid: Identifier forwarded to the base class.
            connection_type (str): The type label stored on the connection.
            weight: Connection weight. Defaults to 1.
            **kwargs: Extra properties forwarded to the base class.
        """
        super().__init__(pre, post, uid, connection_type, weight=weight, **kwargs)
[docs]
class ConnectionGroup(Citable):
    """A named group of connections registered on a network.

    Three views of the same content are kept in step:

    - ``connections``: dict mapping a connection id to its ``Connection``.
    - ``members``: list of the ``Connection`` objects.
    - ``neurons``: set of every ``pre``/``post`` neuron of the members.
    """

    def __init__(self, network, members=None, group_name=None) -> None:
        """
        Initializes a new instance of the ConnectionGroup class.

        Parameters:
            network (NervousSystem):
                The owning nervous system.
            members (Iterable[Connection], optional):
                The connections in the group. Each entry must be a
                ``Connection`` instance — connection identity is the
                ``(pre, post, uid)`` triple, not a string. Defaults to empty.
            group_name (str, optional):
                The name of the connection group. Auto-generated if omitted.

        Raises:
            AssertionError: If a member is not a ``Connection`` or the group
                name is already registered on the network.
        """
        Citable.__init__(self)  # provides self.citations = {}
        if group_name is None:
            self.group_name = "Group-" + generate_random_string(8)
        else:
            self.group_name = group_name

        # Materialise once: a one-shot iterable would otherwise be exhausted
        # by the validation pass and leave the group empty.
        members = [] if members is None else list(members)
        if not all(isinstance(m, Connection) for m in members):
            raise AssertionError("Connection group members must be of type Connection")

        self.members = members
        self.connections = {m._id: m for m in members}
        self.network = network
        self.neurons = {neuron for m in members for neuron in (m.pre, m.post)}

        if self.group_name in self.network.groups:
            raise AssertionError(
                f"Group name {self.group_name} already exists in the network"
            )
        self.network.groups.update({self.group_name: self})

    def _parent_citables(self):
        """Walk citations up to the containing NervousSystem."""
        return (self.network,) if self.network is not None else ()

    def _sync_from_connections(self):
        """Rebuild ``members`` and ``neurons`` from the ``connections`` dict."""
        self.members = list(self.connections.values())
        self.neurons = {
            neuron for m in self.members for neuron in (m.pre, m.post)
        }

    def _require_compatible(self, other):
        """Raise AssertionError unless ``other`` is a group on the same network."""
        if not isinstance(other, ConnectionGroup):
            raise AssertionError("Operand must be a ConnectionGroup")
        if self.network is not other.network:
            raise AssertionError("Both groups must belong to the same network")

    def __iter__(self):
        """
        Returns an iterator over the connection ids (dict keys) of the group.
        """
        return iter(self.connections)

    def clear(self):
        """
        Removes all connections from the group.
        """
        self.connections = {}
        self.members = []
        # Keep the neuron view consistent with the now-empty group.
        self.neurons = set()

    def items(self):
        """
        Returns the itemized connection dictionary
        """
        return self.connections.items()

    def keys(self):
        """Returns the IDs for the Connection Group"""
        return list(self.connections.keys())

    def values(self):
        """Returns the Connection objects of the Connection Group"""
        return list(self.connections.values())

    def __len__(self):
        """
        Returns the number of connections in the group.
        """
        return len(self.connections)

    def __contains__(self, member):
        """
        Returns True if ``member`` is a neuron touched by the group or the id
        of a connection in the group, False otherwise.
        """
        # Imported at call time: the module-level import is guarded by
        # TYPE_CHECKING, so the name does not exist at runtime otherwise.
        from .neuron import Neuron

        if isinstance(member, Neuron):
            return member in self.neurons
        return member in self.connections

    def __getitem__(self, connection_id):
        """
        Returns the connection stored under ``connection_id``.
        """
        return self.connections[connection_id]

    def __setitem__(self, connection_id, connection):
        """
        Stores ``connection`` under ``connection_id`` and refreshes the
        ``members`` and ``neurons`` views.

        Raises:
            AssertionError: If ``connection`` is not a ``Connection``.
        """
        if not isinstance(connection, Connection):
            raise AssertionError("Connection must be of type Connection")
        self.connections[connection_id] = connection
        self._sync_from_connections()

    def update(self, member_dict):
        """
        Merges a ``{connection_id: Connection}`` mapping into the group.

        Raises:
            AssertionError: If any value is not a ``Connection``.
        """
        if not all(
            isinstance(connection, Connection) for connection in member_dict.values()
        ):
            raise AssertionError("Connection group members must be of type Connection")
        self.connections.update(member_dict)
        self._sync_from_connections()

    def pop(self, connection_id):
        """
        Removes the connection stored under ``connection_id`` from the group.

        Returns:
            Connection: The removed connection.

        Raises:
            KeyError: If ``connection_id`` is not in the group.
        """
        removed = self.connections.pop(connection_id)
        # Without this, bulk operations would still touch the removed member.
        self._sync_from_connections()
        return removed

    def set_property(self, property_name, property_value):
        """
        Sets a new property attribute for all connections in the group.
        """
        for connection in self.members:
            connection.set_property(property_name, property_value)

    def get_property(self, property_name):
        """
        Gets the property attribute for all connections in the group.

        Returns:
            list: One value per member, in ``members`` order.
        """
        return [connection.get_property(property_name) for connection in self.members]

    def update_weights(self, weight, delta=False):
        """
        Updates weights for all connections in the group.

        Args:
            weight (float): The new weight value or weight delta.
            delta (bool, optional): If True, weight is added to current weights.
                If False, weight replaces current weights.

        Raises:
            ValueError: If weight is not a numeric value.
        """
        if not isinstance(weight, (Number, np.number)):
            raise ValueError(
                "Weight must be a numeric value, but is of type {}".format(type(weight))
            )
        for connection in self.members:
            connection.update_weight(weight, delta)

    def update_weights_by_function(self, weight_function):
        """
        Updates weights for all connections using a custom function.

        Args:
            weight_function (callable): A function that takes a Connection object
                and returns a new weight value.

        Raises:
            ValueError: If weight_function is not callable or returns non-numeric
                values. Members processed before a bad return value keep
                their updated weights.
        """
        if not callable(weight_function):
            raise ValueError("weight_function must be callable")
        for connection in self.members:
            new_weight = weight_function(connection)
            if not isinstance(new_weight, (Number, np.number)):
                raise ValueError("weight_function must return numeric values")
            connection.update_weight(new_weight)

    def filter_by_type(self, connection_type):
        """
        Returns a new ConnectionGroup containing only connections of the specified type.

        The new group is named ``"<group_name>-<connection_type>"`` and is
        registered on the network, so the call fails if that name is taken.

        Args:
            connection_type (str): The type of connections to filter for.

        Returns:
            ConnectionGroup: A new group containing only the filtered connections.

        Raises:
            AssertionError: If the derived group name already exists.
        """
        filtered_members = [
            m for m in self.members if m.connection_type == connection_type
        ]
        return ConnectionGroup(
            self.network, filtered_members, f"{self.group_name}-{connection_type}"
        )

    def filter_by_property(self, property_name, property_value):
        """
        Returns a new ConnectionGroup containing only connections with the specified property value.

        The new group is named ``"<group_name>-<property_name>-<property_value>"``
        and is registered on the network, so the call fails if that name is taken.

        Args:
            property_name (str): The name of the property to filter by.
            property_value: The value to match.

        Returns:
            ConnectionGroup: A new group containing only the filtered connections.

        Raises:
            AssertionError: If the derived group name already exists.
        """
        filtered_members = [
            m
            for m in self.members
            if hasattr(m, property_name) and getattr(m, property_name) == property_value
        ]
        return ConnectionGroup(
            self.network,
            filtered_members,
            f"{self.group_name}-{property_name}-{property_value}",
        )

    def filter_by_function(self, filter_function):
        """
        Returns a new ConnectionGroup containing only connections that pass the filter function.

        The new group is named ``"<group_name>-filtered"`` and is registered
        on the network, so the call fails if that name is taken.

        Args:
            filter_function (callable): A function that takes a Connection object
                and returns True if the connection should be included.

        Returns:
            ConnectionGroup: A new group containing only the filtered connections.

        Raises:
            ValueError: If filter_function is not callable.
            AssertionError: If the derived group name already exists.
        """
        if not callable(filter_function):
            raise ValueError("filter_function must be callable")
        filtered_members = [m for m in self.members if filter_function(m)]
        return ConnectionGroup(
            self.network, filtered_members, f"{self.group_name}-filtered"
        )

    def get_statistics(self):
        """
        Returns statistics about the connections in the group.

        Returns:
            dict: ``count``, ``weight_mean``, ``weight_min``, ``weight_max``
            (all three are 0 for an empty group) and ``types``, a Counter of
            connection types.
        """
        weights = [m.weight for m in self.members]
        return {
            "count": len(weights),
            "weight_mean": sum(weights) / len(weights) if weights else 0,
            "weight_min": min(weights) if weights else 0,
            "weight_max": max(weights) if weights else 0,
            "types": Counter(m.connection_type for m in self.members),
        }

    # ---- Set operations ----

    def union(
        self, other: "ConnectionGroup", group_name: str = None
    ) -> "ConnectionGroup":
        """Return a new ConnectionGroup containing connections from both groups.

        When both groups hold the same connection id, the object from
        ``other`` is the one kept.

        Args:
            other: Another ConnectionGroup to union with.
            group_name: Optional name for the resulting group.

        Returns:
            A new ConnectionGroup with the combined members.

        Raises:
            AssertionError: If ``other`` is not a ConnectionGroup on the same
                network, or the resulting group name already exists.
        """
        self._require_compatible(other)
        merged = {**self.connections, **other.connections}
        name = group_name or f"{self.group_name}_union_{other.group_name}"
        return ConnectionGroup(
            self.network, members=list(merged.values()), group_name=name
        )

    def intersection(
        self, other: "ConnectionGroup", group_name: str = None
    ) -> "ConnectionGroup":
        """Return a new ConnectionGroup containing only connections present in both groups.

        Members are taken from ``self`` and keep the order they have there.

        Args:
            other: Another ConnectionGroup to intersect with.
            group_name: Optional name for the resulting group.

        Returns:
            A new ConnectionGroup with the shared members.

        Raises:
            AssertionError: If ``other`` is not a ConnectionGroup on the same
                network, or the resulting group name already exists.
        """
        self._require_compatible(other)
        # Walk the dict instead of a set of keys so the result order is deterministic.
        members = [
            conn for key, conn in self.connections.items() if key in other.connections
        ]
        name = group_name or f"{self.group_name}_intersect_{other.group_name}"
        return ConnectionGroup(self.network, members=members, group_name=name)

    def difference(
        self, other: "ConnectionGroup", group_name: str = None
    ) -> "ConnectionGroup":
        """Return a new ConnectionGroup containing connections in self but not in other.

        Members keep the order they have in ``self``.

        Args:
            other: Another ConnectionGroup to subtract.
            group_name: Optional name for the resulting group.

        Returns:
            A new ConnectionGroup with the difference members.

        Raises:
            AssertionError: If ``other`` is not a ConnectionGroup on the same
                network, or the resulting group name already exists.
        """
        self._require_compatible(other)
        members = [
            conn
            for key, conn in self.connections.items()
            if key not in other.connections
        ]
        name = group_name or f"{self.group_name}_diff_{other.group_name}"
        return ConnectionGroup(self.network, members=members, group_name=name)

    def __or__(self, other: "ConnectionGroup") -> "ConnectionGroup":
        """Operator ``|`` — union."""
        return self.union(other)

    def __and__(self, other: "ConnectionGroup") -> "ConnectionGroup":
        """Operator ``&`` — intersection."""
        return self.intersection(other)

    def __sub__(self, other: "ConnectionGroup") -> "ConnectionGroup":
        """Operator ``-`` — difference."""
        return self.difference(other)
[docs]
class Path(ConnectionGroup):
    """An ordered sequence of Connections in the network.

    Consecutive members must chain: the ``post`` neuron of each connection
    is the ``pre`` neuron of the next. ``source`` is the first ``pre`` and
    ``target`` the last ``post`` (both ``None`` for an empty path).
    """

    def __init__(self, network, members=None, group_name=None):
        """
        Validate the chain and register the path as a group on ``network``.

        Args:
            network: The owning network.
            members (Iterable[Connection], optional): Connections in
                traversal order. Defaults to empty.
            group_name (str, optional): Name of the path. Auto-generated
                with a ``"Path-"`` prefix if omitted.

        Raises:
            AssertionError: If a member is not a ``Connection`` or the
                members do not form a continuous chain.
        """
        if group_name is None:
            group_name = "Path-" + generate_random_string(8)

        # Materialise once: the checks below need len()/indexing, and other
        # methods call ``self.members.copy()``, so a tuple or generator
        # must become a list here.
        members = [] if members is None else list(members)
        if not all(isinstance(m, Connection) for m in members):
            raise AssertionError("Path members must be of type Connection")

        # Continuity is only meaningful with more than one connection;
        # zip over adjacent pairs yields nothing for shorter inputs.
        for current, following in zip(members, members[1:]):
            if current.post != following.pre:
                raise AssertionError(
                    "Path members must be continuous connections from source to target"
                )

        self.source = members[0].pre if members else None
        self.target = members[-1].post if members else None
        super().__init__(network, members, group_name)

    def update(self, member_dict):
        """
        Not supported for paths; always raises NotImplementedError.
        """
        raise NotImplementedError(
            f"Cannot update connections in {self.__class__.__name__}"
        )

    def pop(self, connection_id):
        """
        Not supported for paths; always raises NotImplementedError.
        """
        raise NotImplementedError(
            f"Cannot remove connections from {self.__class__.__name__}"
        )

    def get_length(self):
        """
        Returns the number of connections in the path.

        Returns:
            int: The number of connections in the path.
        """
        return len(self.members)

    def get_total_weight(self):
        """
        Returns the sum of weights of all connections in the path.

        Returns:
            float: The sum of connection weights (0 for an empty path).
        """
        return sum(conn.weight for conn in self.members)

    def get_average_weight(self):
        """
        Returns the average weight of connections in the path.

        Returns:
            float: The average connection weight, or 0.0 for an empty path.
        """
        if not self.members:
            return 0.0
        return self.get_total_weight() / len(self.members)

    def get_min_weight(self):
        """
        Returns the minimum weight among all connections in the path.

        Returns:
            float: The minimum connection weight, or 0.0 for an empty path.
        """
        if not self.members:
            return 0.0
        return min(conn.weight for conn in self.members)

    def get_max_weight(self):
        """
        Returns the maximum weight among all connections in the path.

        Returns:
            float: The maximum connection weight, or 0.0 for an empty path.
        """
        if not self.members:
            return 0.0
        return max(conn.weight for conn in self.members)

    def reverse(self):
        """
        Creates a new path with connections in reverse order.

        NOTE: every reversed hop is a newly constructed base ``Connection``
        (same ``connection_type`` and ``weight``, swapped ``pre``/``post``),
        and constructing a ``Connection`` adds an edge to the network. The
        result is registered under ``"<group_name>-reversed"``.

        Returns:
            Path: A new path with reversed connections.
        """
        if not self.members:
            return Path(self.network)

        reversed_connections = [
            Connection(
                conn.post,
                conn.pre,
                connection_type=conn.connection_type,
                weight=conn.weight,
            )
            for conn in reversed(self.members)
        ]
        return Path(self.network, reversed_connections, f"{self.group_name}-reversed")

    def concatenate(self, other_path):
        """
        Creates a new path by concatenating this path with another path.

        If either path is empty, the result is a copy of the other one
        under an auto-generated name.

        Args:
            other_path (Path): The path to concatenate with.

        Returns:
            Path: A new path containing all connections from both paths.

        Raises:
            AssertionError: If the paths cannot be concatenated (target of first path
                must match source of second path).
        """
        if not self.members:
            return Path(self.network, list(other_path.members))
        if not other_path.members:
            return Path(self.network, list(self.members))
        if self.target != other_path.source:
            raise AssertionError(
                "Cannot concatenate paths: target of first path must match source of second path"
            )
        combined_connections = list(self.members) + list(other_path.members)
        return Path(
            self.network,
            combined_connections,
            f"{self.group_name}-{other_path.group_name}",
        )

    def is_valid(self):
        """
        Checks if the path is valid (all connections are continuous).

        Returns:
            bool: True if the path is valid, False otherwise.
        """
        if not self.members:
            return True
        try:
            return not any(
                current.post != following.pre
                for current, following in zip(self.members, self.members[1:])
            )
        except Exception:
            # Deliberately best-effort: a member that cannot be inspected
            # makes the path invalid instead of propagating the error.
            return False

    def get_neurons(self):
        """
        Returns a list of neurons in the path in order.

        Returns:
            list: ``source`` followed by the ``post`` neuron of every
            connection; empty for an empty path.
        """
        if not self.members:
            return []
        return [self.source] + [conn.post for conn in self.members]

    def get_connection_types(self):
        """
        Returns a list of connection types in the path in order.

        Returns:
            list: List of connection types.
        """
        return [conn.connection_type for conn in self.members]