Source code for cedne.core.neuron

"""
Neuron and cell-level primitives for CeDNe.

This module defines the core data structures for representing cells in the nervous system.
It includes:

- `Cell`: A base class for any biological cell modeled in the nervous system.
- `Neuron`: A subclass of `Cell` specialized for neural structures, supporting
  connectivity, trial-specific recordings, and calcium feature extraction.
- `NeuronGroup`: A container for managing sets of neurons with shared structure,
  metadata, or functional properties.

Neurons are stored within a `NervousSystem` graph, and may maintain their own 
set of incoming and outgoing `Connection` objects. Each neuron can host multiple
`Trial` objects, representing experimental recordings under different conditions.
"""

__author__ = "Sahil Moza"
__date__ = "2025-04-06"
__license__ = "MIT"

import networkx as nx
import copy
from .io import generate_random_string
from typing import Optional, List, Dict, Any, Union, TYPE_CHECKING
from dataclasses import dataclass, field
import numpy as np
from .config import F_SAMPLE
from .recordings import Trial
from .connection import Path

if TYPE_CHECKING:
    from .network import NervousSystem
    from .connection import Path
    from .recordings import Trial

class Cell:
    """Models a biological cell that lives as a node in a network."""

    def __init__(self, name, network, **kwargs):
        """Create the cell and register it as a node of ``network``.

        Args:
            name (str): The name of the cell.
            network: The network object the cell belongs to. It must provide
                ``add_node(node, **attributes)``.
            **kwargs: Arbitrary descriptive properties. Each key/value pair is
                set as an attribute on the cell and is also forwarded to
                ``network.add_node`` as a node attribute.

        Raises:
            TypeError: If ``name`` is not a string.
        """
        if not isinstance(name, str):
            raise TypeError("name must be a string")
        self.name = name
        self.group_id = 0
        self._data = {}
        self.network = network
        for key, value in kwargs.items():
            setattr(self, key, value)
        # Assigned after the kwargs loop, so a keyword argument can never
        # replace the connection dictionaries.
        self.in_connections = {}
        self.out_connections = {}
        self.network.add_node(self, **kwargs)
class Neuron(Cell):
    """Models a biological neuron.

    On top of what ``Cell`` provides, a neuron registers itself in
    ``network.neurons`` under its name, keeps a dictionary of ``Trial``
    objects keyed by trial number, and offers helpers to query its
    connections and the paths that link it to other neurons.
    """

    def __init__(self, name: str, network: 'NervousSystem', **kwargs):
        """Create the neuron and register it with ``network``.

        Args:
            name (str): The name of the neuron; must be unique within
                ``network.neurons``.
            network (NervousSystem): The network the neuron belongs to.
            **kwargs: Arbitrary properties, handled by ``Cell.__init__``.

        Raises:
            ValueError: If a neuron with the given name already exists in
                the network.
        """
        if name in network.neurons:
            raise ValueError(f"Neuron with name '{name}' already exists in the network")
        super().__init__(name, network, **kwargs)
        self.network.neurons[name] = self
        self.trial = {}
        # Index -> feature-name lookup table.
        self.features = {
            0: 'Ca_max',
            1: 'Ca_area',
            2: 'Ca_avg',
            3: 'Ca_time_to_peak',
            4: 'Ca_area_to_peak',
            5: 'Ca_min',
            6: 'Ca_onset',
            7: 'positive_area',
            8: 'positive_time',
        }

    def add_trial(self, trial_num=0):
        """Create a ``Trial`` for this neuron and store it under ``trial_num``.

        An existing trial stored under the same number is replaced.

        Returns:
            Trial: The newly created trial object.
        """
        self.trial[trial_num] = Trial(self, trial_num)
        return self.trial[trial_num]

    def remove_trial(self, trial_num):
        """Remove the trial stored under ``trial_num``.

        Raises:
            KeyError: If no trial is stored under ``trial_num``.
        """
        del self.trial[trial_num]

    def _connections_for(self, direction, paired_neuron=None):
        """Return the connection dict for ``direction``, optionally limited to one partner."""
        if direction not in ('both', 'in', 'out'):
            raise ValueError('Direction must be either "both", "in", or "out"')
        if direction == 'in':
            return self.incoming(paired_neuron)
        if direction == 'out':
            return self.outgoing(paired_neuron)
        return self.incoming(paired_neuron) | self.outgoing(paired_neuron)

    def get_connections(self, paired_neuron=None, direction='both', connection_type='all'):
        """Return the connections this neuron takes part in.

        Args:
            paired_neuron (Neuron, optional): If given, only connections
                between this neuron and ``paired_neuron`` are returned.
            direction (str): ``'in'``, ``'out'`` or ``'both'``.
            connection_type: ``'all'`` to skip filtering, otherwise only
                connections whose ``connection_type`` attribute equals this
                value are kept.

        Returns:
            dict: Mapping of edge id to connection object. With
            ``connection_type='all'`` and a single direction and no
            ``paired_neuron``, this is the neuron's own dictionary, not a copy.

        Raises:
            ValueError: If ``direction`` is not one of the accepted values.
            TypeError: If ``paired_neuron`` is neither ``None`` nor a Neuron.
        """
        conns = self._connections_for(direction, paired_neuron)
        if connection_type == 'all':
            return conns
        return {edge: conn for edge, conn in conns.items()
                if conn.connection_type == connection_type}

    def get_connected_neurons(self, direction='both', weight_filter=1, connection_type='all'):
        """Return the endpoints of all connections that pass the filters.

        A connection is kept when its ``weight`` is strictly greater than
        ``weight_filter`` and, unless ``connection_type`` is ``'all'``, its
        ``connection_type`` attribute matches.

        Note:
            Both endpoints of every kept edge are collected, so the returned
            set contains this neuron itself whenever at least one connection
            passes the filters.

        Returns:
            set: The first two elements of every kept edge id.

        Raises:
            ValueError: If ``direction`` is not one of the accepted values.
        """
        conns = self._connections_for(direction)
        connected = set()
        for edge, conn in conns.items():
            if conn.weight <= weight_filter:
                continue
            if connection_type != 'all' and conn.connection_type != connection_type:
                continue
            connected.add(edge[0])
            connected.add(edge[1])
        return connected

    def update_connections(self):
        """Rebuild ``in_connections`` and ``out_connections`` from the network.

        Edge ids come from ``network.in_edges`` / ``network.out_edges`` (with
        keys) and are looked up in ``network.connections``.
        """
        self.in_connections = {
            _id: self.network.connections[_id]
            for _id in self.network.in_edges(self, keys=True)
        }
        self.out_connections = {
            _id: self.network.connections[_id]
            for _id in self.network.out_edges(self, keys=True)
        }

    def outgoing(self, paired_neuron=None):
        """Return outgoing connections, optionally only those to ``paired_neuron``.

        Returns:
            dict: Mapping of edge id to connection object. Without
            ``paired_neuron`` this is ``self.out_connections`` itself.

        Raises:
            TypeError: If ``paired_neuron`` is neither ``None`` nor a Neuron.
        """
        if paired_neuron is None:
            return self.out_connections
        if isinstance(paired_neuron, Neuron):
            return {edge: conn for edge, conn in self.out_connections.items()
                    if edge[0] == self and edge[1] == paired_neuron}
        raise TypeError('paired_neuron must be a Neuron object')

    def incoming(self, paired_neuron=None):
        """Return incoming connections, optionally only those from ``paired_neuron``.

        Returns:
            dict: Mapping of edge id to connection object. Without
            ``paired_neuron`` this is ``self.in_connections`` itself.

        Raises:
            TypeError: If ``paired_neuron`` is neither ``None`` nor a Neuron.
        """
        if paired_neuron is None:
            return self.in_connections
        if isinstance(paired_neuron, Neuron):
            return {edge: conn for edge, conn in self.in_connections.items()
                    if edge[1] == self and edge[0] == paired_neuron}
        raise TypeError('paired_neuron must be a Neuron object')

    def set_property(self, property_name, property_value):
        """Set a property on the neuron and mirror it as a node attribute.

        Args:
            property_name (str): The name of the property.
            property_value: The value of the property.
        """
        setattr(self, property_name, property_value)
        nx.set_node_attributes(self.network, {self: {property_name: property_value}})

    def get_property(self, key):
        """Return the attribute named ``key``.

        Raises:
            AttributeError: If the neuron has no such attribute.
        """
        return getattr(self, key)

    def connects_to(self, other):
        """Return True if any outgoing or incoming connection links to ``other``."""
        return (any(edge[1] == other for edge in self.out_connections)
                or any(edge[0] == other for edge in self.in_connections))

    def paths_to(self, target, path_length=1):
        """Return all simple paths from this neuron to ``target``.

        Paths are searched with ``cutoff=path_length``. Groups already stored
        in ``network.groups`` under the matching name prefix are reused when
        their count equals the number of paths found; otherwise new ``Path``
        objects are built.

        Returns:
            list: ``Path`` objects, one per simple path.
        """
        # The trailing underscore keeps e.g. length 1 from also matching
        # cached entries for length 10, 11, ...
        prefix = f'Path_{self.name}_{target.name}_length_{path_length}_'
        # NOTE(review): presumably Path registers itself in network.groups
        # under its name, which is what makes this lookup a cache - confirm.
        cached = [self.network.groups[group] for group in self.network.groups
                  if group.startswith(prefix)]
        edge_paths = nx.all_simple_edge_paths(self.network, self, target, cutoff=path_length)
        connection_paths = [[self.network.connections[edge] for edge in path]
                            for path in edge_paths]
        if len(cached) == len(connection_paths):
            return cached
        return [Path(self.network, path, f'{prefix}{j}')
                for j, path in enumerate(connection_paths)]

    def _directed_paths(self, path_length, direction):
        """Build ``Path`` objects for every simple path in one direction ('in' or 'out')."""
        per_neuron = []
        for name in self.network.neurons:
            other = self.network.neurons[name]
            source, target = (self, other) if direction == 'out' else (other, self)
            edge_paths = nx.all_simple_edge_paths(
                self.network, source, target, cutoff=path_length)
            per_neuron.append([[self.network.connections[edge] for edge in path]
                               for path in edge_paths])
        # Path objects are created only after every path has been collected.
        return [
            Path(self.network, path,
                 f'Path_{self.name}_{direction}_length_{path_length}_{j}_{k}')
            for k, paths in enumerate(per_neuron)
            for j, path in enumerate(paths)
        ]

    def all_paths(self, path_length=1, direction='both'):
        """Return all simple paths between this neuron and every neuron in the network.

        Args:
            path_length (int): Cutoff passed to the simple-path search.
            direction (str): ``'out'`` for paths starting here, ``'in'`` for
                paths ending here, ``'both'`` for outgoing paths followed by
                incoming paths.

        Returns:
            list: ``Path`` objects.

        Raises:
            ValueError: If ``direction`` is not one of the accepted values.
        """
        if direction == 'out':
            return self._directed_paths(path_length, 'out')
        if direction == 'in':
            return self._directed_paths(path_length, 'in')
        if direction == 'both':
            return (self._directed_paths(path_length, 'out')
                    + self._directed_paths(path_length, 'in'))
        raise ValueError('Direction must be either "both", "in", or "out"')

    def __str__(self):
        """Return the neuron's name (handy for debugging and testing)."""
        return self.name
class NeuronGroup:
    """A named group of neurons registered in a network.

    The group keeps two views of its content: ``neurons`` (name -> neuron)
    and ``members`` (list of neurons). Every mutating method keeps the two
    in step.
    """

    def __init__(self, network, members=None, group_name=None) -> None:
        """Create the group and register it in ``network.groups``.

        Args:
            network: The network the group belongs to; must expose a
                dict-like ``groups`` attribute.
            members (optional): Neurons that form the group. Defaults to an
                empty group.
            group_name (str, optional): Name of the group. When omitted, a
                name of the form ``'Group-'`` plus a random string is used.

        Raises:
            AssertionError: If a member is not a ``Neuron`` or if the group
                name is already present in ``network.groups``.
        """
        if group_name is None:
            self.group_name = 'Group-' + generate_random_string(8)
        else:
            self.group_name = group_name

        if members is None:
            members = []
        else:
            assert all(isinstance(m, Neuron) for m in members), \
                "Neuron group members must be of type Neuron"
        self.members = members
        self.neurons = {m.name: m for m in members}
        self.network = network

        assert self.group_name not in self.network.groups, \
            f"Group name {self.group_name} already exists in the network"
        self.network.groups.update({self.group_name: self})

    def _sync_members(self):
        """Rebuild ``members`` from ``neurons`` so both views stay consistent."""
        self.members = list(self.neurons.values())

    def __iter__(self):
        """Return an iterator over the names of the neurons in the group."""
        return iter(self.neurons)

    def items(self):
        """Yield ``(name, neuron)`` pairs for every neuron in the group."""
        for key, value in self.neurons.items():
            yield key, value

    def keys(self):
        """Return a list with the names of the neurons in the group."""
        return list(self.neurons.keys())

    def values(self):
        """Return a list with the neurons in the group."""
        return list(self.neurons.values())

    def __len__(self):
        """Return the number of neurons in the group."""
        return len(self.neurons)

    def __contains__(self, neuron):
        """Return True if a neuron is stored under the name ``neuron``."""
        return neuron in self.neurons

    def __getitem__(self, neuron_name):
        """Return the neuron stored under ``neuron_name``."""
        return self.neurons[neuron_name]

    def __setitem__(self, neuron_name, neuron):
        """Store ``neuron`` under ``neuron_name``.

        Raises:
            AssertionError: If ``neuron`` is not a ``Neuron``.
        """
        assert isinstance(neuron, Neuron), "Neuron group members must be of type Neuron"
        self.neurons[neuron_name] = neuron
        self._sync_members()

    def clear(self):
        """Remove all neurons from the group."""
        self.neurons = {}
        self.members = []

    def update(self, member_dict):
        """Add or replace neurons from a ``{name: neuron}`` mapping.

        Raises:
            AssertionError: If any value is not a ``Neuron``.
        """
        assert all(isinstance(neuron, Neuron) for neuron in member_dict.values()), \
            "Neuron group members must be of type Neuron"
        self.neurons.update(member_dict)
        self._sync_members()

    def pop(self, neuron_name):
        """Delete the neuron stored under ``neuron_name`` from the group.

        Raises:
            KeyError: If no neuron is stored under ``neuron_name``.
        """
        self.neurons.pop(neuron_name)
        self._sync_members()

    def set_property(self, property_name, property_value):
        """Set a property on every neuron in the group."""
        for neuron in self.members:
            neuron.set_property(property_name, property_value)

    def get_property(self, property_name):
        """Return the value of ``property_name`` for every neuron, as a list."""
        return [neuron.get_property(property_name) for neuron in self.members]

    def get_connections(self):
        """Return one ``get_connections()`` result per neuron, as a list."""
        return [neuron.get_connections() for neuron in self.members]

    def add_neuron(self, neuron: 'Neuron') -> None:
        """Add a neuron to the group under its own name.

        A neuron whose name is already present is left untouched.

        Args:
            neuron: Neuron to add.
        """
        # Membership is keyed by name, so the name (not the object) is tested.
        if neuron.name not in self.neurons:
            self.neurons[neuron.name] = neuron
            self._sync_members()

    def remove_neuron(self, neuron: 'Neuron') -> None:
        """Remove a neuron from the group; do nothing if it is not a member.

        Args:
            neuron: Neuron to remove.
        """
        # Membership is keyed by name, so the name (not the object) is tested.
        if neuron.name in self.neurons:
            self.neurons.pop(neuron.name)
            self._sync_members()

    def get_neurons_by_type(self, type: str) -> List['Neuron']:
        """Get all neurons whose ``type`` attribute equals ``type``.

        Args:
            type: Neuron type to filter by.

        Returns:
            List of neurons of the specified type.
        """
        return [n for n in self.neurons.values() if n.type == type]

    def get_neurons_by_property(self, key: str, value: Any) -> List['Neuron']:
        """Get neurons with a specific property value.

        Args:
            key: Property name.
            value: Property value to match.

        Returns:
            List of neurons with matching property value.
        """
        return [n for n in self.neurons.values() if n.get_property(key) == value]