# Source code for cedne.core.neuron

"""
Neuron and cell-level primitives for CeDNe.

This module defines the core data structures for representing cells in the nervous system.
It includes:

- `Cell`: A base class for any biological cell modeled in the nervous system.
- `Neuron`: A subclass of `Cell` specialized for neural structures, supporting
  connectivity, trial-specific recordings, and calcium feature extraction.
- `NeuronGroup`: A container for managing sets of neurons with shared structure,
  metadata, or functional properties.

Neurons are stored within a `NervousSystem` graph, and may maintain their own
set of incoming and outgoing `Connection` objects. Each neuron can host multiple
`Trial` objects, representing experimental recordings under different conditions.
"""

# Module metadata (author, last-edit date, license).
__author__ = "Sahil Moza"
__date__ = "2025-04-06"
__license__ = "MIT"

import math
import networkx as nx
from .io import generate_random_string
from typing import List, Dict, Any, TYPE_CHECKING
import numpy as np
from .recordings import Trial
from .connection import Path
from .source import Citable, serialize_citations

if TYPE_CHECKING:
    from .network import NervousSystem
    from .connection import Path
    from .recordings import Trial

#: Sentinel value assigned to enum-like neuron attributes (``type``,
#: ``category``, ``modality``, ...) when the constituents of a merged
#: neuron disagree on that attribute. Analyses that switch on those
#: attributes should branch on this value explicitly rather than silently
#: treating a merged neuron as its first constituent's value.
#:
#: (Written as ``#:`` doc-comments so documentation tools attach the text
#: to ``MERGED_TYPE``; a bare string placed before the assignment is a
#: no-op expression statement that is not associated with the constant.)
MERGED_TYPE = "merged"

#: Enum-like attributes that ``NervousSystem.contract_neurons`` merges with
#: the rule "every constituent agrees -> keep that value, otherwise ->
#: 'merged'". Listing an attribute here means that it is automatically:
#:
#: * snapshotted for each constituent under
#:   ``Neuron.constituents[name][attr]``,
#: * resolved on the surviving neuron using the rule above, and
#: * mirrored onto the networkx node attribute dict, so code paths that
#:   rebuild neurons via ``create_neurons_from(data=True)`` carry it along.
#:
#: Only bounded categorical labels belong here. Numeric properties
#: (degree, length, transcript counts) need their own aggregation policy
#: and are intentionally left out.
MERGE_TRACK_ATTRS = (
    "type",
    "category",
    "modality",
)


class Cell:
    """Base class for any biological cell placed in a nervous-system graph.

    The cell object itself is used as a graph node. Every keyword argument
    given at construction becomes both an instance attribute and a node
    attribute of ``network``.
    """

    def __init__(self, name, network, **kwargs):
        """Create a cell and add it to ``network`` as a graph node.

        Args:
            name (str): Name of the cell.
            network: Owning graph (a ``NervousSystem`` in practice). Only its
                ``add_node`` method is used here.
            **kwargs: Arbitrary attributes (for example ``type``,
                ``category``, ``modality``, ``position``). No defaults are
                applied, so an attribute exists only if it is passed. Each
                pair is set on the instance and forwarded unchanged to
                ``network.add_node``. ``in_connections`` and
                ``out_connections`` are always reset to empty dicts on the
                instance afterwards.

        Raises:
            TypeError: If ``name`` is not a string.
        """
        if not isinstance(name, str):
            raise TypeError("name must be a string")

        # Core bookkeeping is assigned before kwargs, so callers may
        # override e.g. ``group_id`` through kwargs.
        self.name = name
        self.group_id = 0
        self._data = {}
        self.network = network

        for attr_name in kwargs:
            setattr(self, attr_name, kwargs[attr_name])

        # Connection caches always start empty, regardless of kwargs.
        self.in_connections, self.out_connections = {}, {}

        self.network.add_node(self, **kwargs)
class Neuron(Cell, Citable):
    """Models a biological neuron.

    A ``Neuron`` is a :class:`Cell` that also registers itself in
    ``network.neurons`` under its name. It holds per-trial recordings in
    ``self.trial``, calcium feature labels in ``self.features``, and
    dimensionality-reduction loadings in ``self.loadings``. Its connections
    are cached in ``in_connections`` / ``out_connections``, which are
    refreshed by :meth:`update_connections`.
    """

    # Attributes excluded from automatic scalar introspection in to_dict()
    _SERIALIZE_SKIP = frozenset(
        {
            "name",
            "network",
            "in_connections",
            "out_connections",
            "trial",
            "features",
            "spatial_mask",
            "_data",
            "loadings",
            "position",
            "transcript",
            "group_id",
            "citations",
            # `constituents` is a dict (not a numeric scalar) — handled
            # explicitly in to_dict() via the merged-neuron block.
            "constituents",
        }
    )

    def __init__(self, name: str, network: "NervousSystem", **kwargs):
        """Create a neuron and register it with ``network``.

        Args:
            name (str): Name of the neuron. It must be unique within
                ``network.neurons``.
            network (NervousSystem): The nervous system that owns the neuron.
            **kwargs: Arbitrary attributes (e.g. ``type``, ``category``,
                ``modality``, ``position``). :class:`Cell` sets each one on
                the instance and passes it to ``network.add_node``. The keys
                ``trial``, ``features``, ``loadings`` and ``spatial_mask``
                also receive defaults here when absent.

        Raises:
            ValueError: If a neuron with the given name already exists in
                the network.
            TypeError: If ``name`` is not a string (raised by ``Cell``).
        """
        if name in network.neurons:
            raise ValueError(f"Neuron with name '{name}' already exists in the network")
        Cell.__init__(self, name, network, **kwargs)
        # Note: this resets self.citations to {} even if passed via kwargs.
        Citable.__init__(self)  # provides self.citations = {}
        self.network.neurons[name] = self

        self.trial = kwargs.pop("trial", {})
        self.features = kwargs.pop(
            "features",
            {
                0: "Ca_max",
                1: "Ca_area",
                2: "Ca_avg",
                3: "Ca_time_to_peak",
                4: "Ca_area_to_peak",
                5: "Ca_min",
                6: "Ca_onset",
                7: "positive_area",
                8: "positive_time",
            },
        )
        # Loadings from dimensionality reduction (SVD/PCA/NMF)
        # e.g. {"SVD_PC1": 0.45, "SVD_PC2": -0.12, "SVD_PC3": 0.03}
        self.loadings = kwargs.pop("loadings", {})
        self.spatial_mask = kwargs.pop("spatial_mask", None)

    # ------------------------------------------------------------------
    # Merge provenance
    # ------------------------------------------------------------------

    @property
    def is_merged(self) -> bool:
        """True iff this neuron was produced by contracting other neurons.

        Backed by ``self.constituents``: a non-empty dict ⇒ merged. Set by
        ``NervousSystem.contract_neurons``; round-trips through graph
        cloning via the networkx node attribute mirror.
        """
        return bool(getattr(self, "constituents", None))

    def _constituent_values(self, attr):
        """Return the sorted, distinct, non-empty string values of ``attr``
        across this neuron's constituents.

        This is the shared helper behind the per-attribute properties below,
        so the policy lives in one place. Non-string values (``None``,
        ``NaN`` from upstream pandas loaders, numeric junk) are dropped.
        These enum-like attributes only carry meaning as string labels, and
        mixed-type sets would also break ``sorted``.
        """
        if not getattr(self, "constituents", None):
            return []
        values = set()
        for meta in self.constituents.values():
            v = meta.get(attr)
            if isinstance(v, str) and v:
                values.add(v)
        return sorted(values)

    @property
    def constituent_types(self):
        """Sorted distinct source types across this neuron's constituents.

        Derived on demand from ``self.constituents``, so there is no
        parallel field to keep in sync. ``[]`` for un-merged neurons.
        """
        return self._constituent_values("type")

    @property
    def constituent_categories(self):
        """Sorted distinct source categories across this neuron's
        constituents. ``[]`` for un-merged neurons or for merged neurons
        whose constituents had no category set."""
        return self._constituent_values("category")

    @property
    def constituent_modalities(self):
        """Sorted distinct source modalities across this neuron's
        constituents. ``[]`` for un-merged neurons or for merged neurons
        whose constituents had no modality set."""
        return self._constituent_values("modality")

    def _parent_citables(self):
        """Walk citations up through containing NeuronGroups and the NervousSystem."""
        parents = []
        if self.network is not None:
            for group in self.network.groups.values():
                # O(1) membership via the name->Neuron dict each NeuronGroup maintains
                if (
                    isinstance(group, NeuronGroup)
                    and group.neurons.get(self.name) is self
                ):
                    parents.append(group)
            parents.append(self.network)
        return parents

    # ------------------------------------------------------------------
    # Serialization
    # ------------------------------------------------------------------

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the neuron to a plain Python dictionary.

        The result always contains the keys ``id``, ``type``, ``group_id``,
        ``degree`` and ``has_recordings``. It may also contain:

        * ``position``: a copy of a non-empty dict position.
        * ``scalars``: auto-discovered numeric attributes. Only finite
          ``int``/``float`` values, or non-empty lists/tuples made entirely
          of finite numbers, are included.
        * ``loadings``: finite dimensionality-reduction weights.
        * ``citations``: citations attached directly to this neuron.
        * Merge provenance, for merged neurons only.

        Returns:
            Dict[str, Any]: A JSON-compatible dictionary representation.
        """
        d: Dict[str, Any] = {
            "id": self.name,
            "type": getattr(self, "type", "Unknown"),
            "group_id": getattr(self, "group_id", 0),
            "degree": len(self.in_connections) + len(self.out_connections),
            "has_recordings": bool(self.trial),
        }

        # Position — serialize as-is (no LR string normalization)
        if hasattr(self, "position") and self.position:
            if isinstance(self.position, dict):
                d["position"] = dict(self.position)

        # Numeric scalar attributes (dynamic introspection)
        scalars: Dict[str, Any] = {}
        for attr in vars(self):
            if attr.startswith("_") or attr in self._SERIALIZE_SKIP:
                continue
            try:
                val = getattr(self, attr)
                if isinstance(val, (int, float)) and math.isfinite(val):
                    scalars[attr] = val
                elif (
                    isinstance(val, (list, tuple))
                    and val
                    # Same finiteness rule as the scalar branch: NaN/inf
                    # are not valid JSON.
                    and all(
                        isinstance(v, (int, float)) and math.isfinite(v)
                        for v in val
                    )
                ):
                    scalars[attr] = list(val)
            except Exception:
                # Best-effort introspection: skip attributes that cannot be
                # checked (e.g. math.isfinite raises OverflowError for ints
                # too large to convert to float).
                pass
        if scalars:
            d["scalars"] = scalars

        # Loadings from dimensionality reduction
        if hasattr(self, "loadings") and self.loadings:
            clean_loadings: Dict[str, float] = {}
            for k, v in self.loadings.items():
                try:
                    fv = float(v)
                except (ValueError, TypeError):
                    continue
                if math.isfinite(fv):
                    clean_loadings[k] = fv
            if clean_loadings:
                d["loadings"] = clean_loadings

        # Citations attached directly to this neuron (not the inherited chain)
        if hasattr(self, "citations") and self.citations:
            d["citations"] = serialize_citations(self.citations)

        # Merge provenance — `is_merged` is only emitted when actually
        # merged so un-merged neurons stay payload-slim. Each constituent
        # snapshot includes every MERGE_TRACK_ATTRS value captured at
        # merge time.
        if self.is_merged:
            d["is_merged"] = True
            d["constituents"] = [
                {
                    "name": meta.get("name", name),
                    **{a: meta.get(a, "") for a in MERGE_TRACK_ATTRS},
                }
                for name, meta in self.constituents.items()
            ]
            d["constituent_types"] = self.constituent_types
            cats = self.constituent_categories
            if cats:
                d["constituent_categories"] = cats
            mods = self.constituent_modalities
            if mods:
                d["constituent_modalities"] = mods

        return d

    # ------------------------------------------------------------------
    # Trials / recordings
    # ------------------------------------------------------------------

    def add_trial(self, trial_num=0):
        """Create a new ``Trial`` for this neuron under key ``trial_num``.

        Any existing trial with the same key is replaced. ``self.trial`` is
        mirrored onto the networkx node attributes, so ``**data`` copies of
        the node carry it.

        Args:
            trial_num: Key of the trial in ``self.trial`` (default 0).

        Returns:
            Trial: The newly added trial object.
        """
        self.trial[trial_num] = Trial(self, trial_num)
        # B4: Sync to node attributes so **data works in subnetwork/copy
        nx.set_node_attributes(self.network, {self: {"trial": self.trial}})
        return self.trial[trial_num]

    def load_recording(self, data, trial_num=0, sampling_rate=None, metadata=None):
        """Attach a 1D time-series recording to this neuron as a new Trial.

        Thin convenience over `add_trial()` + setting `trial.recording`; use
        it when you already have a per-neuron trace in memory (ndarray,
        list, pandas Series) and don't want to manage Trial objects by hand.

        Args:
            data: 1D array-like of samples (calcium trace, voltage, etc.).
            trial_num: Trial index (default 0). Overwrites any existing
                trial at this index for this neuron.
            sampling_rate: Sampling rate in Hz; written to
                ``trial.metadata``. Leave as None to keep the module default
                (F_SAMPLE).
            metadata: Extra metadata dict merged into ``trial.metadata``.

        Returns:
            Trial: The trial object holding the recording.
        """
        trial = self.add_trial(trial_num)
        trial.recording = np.asarray(data)
        if sampling_rate is not None:
            trial.metadata["sampling_rate"] = float(sampling_rate)
        if metadata:
            trial.metadata.update(metadata)
        return trial

    def remove_trial(self, trial_num):
        """Remove the trial stored under ``trial_num``.

        Raises:
            KeyError: If no trial exists under ``trial_num``.
        """
        del self.trial[trial_num]

    # ------------------------------------------------------------------
    # Connectivity
    # ------------------------------------------------------------------

    def get_connections(
        self, paired_neuron=None, direction="both", connection_type="all"
    ):
        """Return the connections this neuron is involved in.

        Args:
            paired_neuron (Neuron, optional): Restrict to connections
                between this neuron and ``paired_neuron``.
            direction (str): ``"in"``, ``"out"`` or ``"both"``.
            connection_type (str): ``"all"``, or a value to match against
                each connection's ``connection_type``.

        Returns:
            dict: Mapping of edge key -> connection. With
            ``connection_type="all"``, no ``paired_neuron`` and a single
            direction, this is the neuron's live ``in_connections`` /
            ``out_connections`` dict, so copy it before mutating.

        Raises:
            ValueError: If ``direction`` is not ``"both"``, ``"in"`` or ``"out"``.
            TypeError: If ``paired_neuron`` is given but is not a Neuron.
        """
        if direction not in ("both", "in", "out"):
            raise ValueError('Direction must be either "both", "in", or "out"')

        if paired_neuron is None:
            if direction == "in":
                conns = self.in_connections
            elif direction == "out":
                conns = self.out_connections
            else:
                conns = self.in_connections | self.out_connections
        else:
            if direction == "in":
                conns = self.incoming(paired_neuron)
            elif direction == "out":
                conns = self.outgoing(paired_neuron)
            else:
                conns = self.outgoing(paired_neuron) | self.incoming(paired_neuron)

        if connection_type == "all":
            return conns
        # Iterate .items(): iterating the dict itself yields only edge keys.
        return {
            key: conn
            for key, conn in conns.items()
            if conn.connection_type == connection_type
        }

    def get_connected_neurons(
        self, direction="both", weight_filter=1, connection_type="all"
    ):
        """Return the set of neurons on connections of this neuron.

        Both endpoints of every qualifying edge are collected, so the result
        also contains this neuron itself whenever any connection qualifies.

        Args:
            direction (str): ``"in"``, ``"out"`` or ``"both"``.
            weight_filter: Only connections with ``weight`` strictly greater
                than this value are considered.
            connection_type (str): ``"all"``, or a value to match against
                each connection's ``connection_type``.

        Returns:
            set: Endpoint neurons of the qualifying connections.

        Raises:
            ValueError: If ``direction`` is not ``"both"``, ``"in"`` or ``"out"``.
        """
        if direction == "both":
            conns = self.in_connections | self.out_connections
        elif direction == "in":
            conns = self.in_connections
        elif direction == "out":
            conns = self.out_connections
        else:
            raise ValueError('Direction must be either "both", "in", or "out"')

        return {
            endpoint
            for edge, conn in conns.items()
            if conn.weight > weight_filter
            and (connection_type == "all" or conn.connection_type == connection_type)
            for endpoint in (edge[0], edge[1])
        }

    def update_connections(self):
        """Rebuild ``in_connections`` / ``out_connections`` from the network.

        Keys are the ``(u, v, key)`` edge tuples returned by
        ``in_edges``/``out_edges`` with ``keys=True``. Values are looked up
        in ``self.network.connections``.
        """
        self.in_connections = {
            _id: self.network.connections[_id]
            for _id in self.network.in_edges(self, keys=True)
        }
        self.out_connections = {
            _id: self.network.connections[_id]
            for _id in self.network.out_edges(self, keys=True)
        }

    def outgoing(self, paired_neuron=None):
        """Return outgoing connections, optionally only those to ``paired_neuron``.

        Returns:
            dict: Mapping of edge key -> connection.

        Raises:
            TypeError: If ``paired_neuron`` is given but is not a Neuron.
        """
        if paired_neuron is None:
            return self.out_connections
        if isinstance(paired_neuron, Neuron):
            return {
                edge: conn
                for edge, conn in self.out_connections.items()
                if edge[0] == self and edge[1] == paired_neuron
            }
        raise TypeError("paired_neuron must be a Neuron object")

    def incoming(self, paired_neuron=None):
        """Return incoming connections, optionally only those from ``paired_neuron``.

        Returns:
            dict: Mapping of edge key -> connection.

        Raises:
            TypeError: If ``paired_neuron`` is given but is not a Neuron.
        """
        if paired_neuron is None:
            return self.in_connections
        if isinstance(paired_neuron, Neuron):
            return {
                edge: conn
                for edge, conn in self.in_connections.items()
                if edge[1] == self and edge[0] == paired_neuron
            }
        raise TypeError("paired_neuron must be a Neuron object")

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    def set_property(self, property_name, property_value):
        """Set an attribute on the neuron and mirror it onto the graph node.

        Args:
            property_name (str): The name of the property.
            property_value: The value of the property.
        """
        setattr(self, property_name, property_value)
        nx.set_node_attributes(self.network, {self: {property_name: property_value}})

    def get_property(self, key):
        """Return the attribute ``key`` (raises AttributeError if missing)."""
        return getattr(self, key)

    def connects_to(self, other):
        """Return True if an edge exists between this neuron and ``other`` in
        either direction."""
        return any(edge[1] == other for edge in self.out_connections) or any(
            edge[0] == other for edge in self.in_connections
        )

    # ------------------------------------------------------------------
    # Paths
    # ------------------------------------------------------------------

    def _connection_paths(self, source, target, path_length):
        """List of connection lists, one per simple edge path from ``source``
        to ``target`` found by networkx with ``cutoff=path_length``."""
        return [
            [self.network.connections[edge] for edge in path]
            for path in nx.all_simple_edge_paths(
                self.network, source, target, cutoff=path_length
            )
        ]

    def _build_paths(self, connection_paths, label, path_length):
        """Wrap nested ``[neuron_index][path_index]`` connection lists in
        ``Path`` objects named ``Path_<self>_<label>_length_<n>_<j>_<k>``."""
        return [
            Path(
                self.network,
                path,
                f"Path_{self.name}_{label}_length_{path_length}_{j}_{k}",
            )
            for k, paths in enumerate(connection_paths)
            for j, path in enumerate(paths)
        ]

    def paths_to(self, target, path_length=1):
        """Return ``Path`` objects for simple paths from this neuron to ``target``.

        Paths come from ``nx.all_simple_edge_paths`` with
        ``cutoff=path_length``. The method looks for groups already
        registered in ``network.groups`` under the names this method
        generates. If their count equals the number of paths found, those
        existing groups are returned. Otherwise new ``Path`` objects are
        built.

        Args:
            target (Neuron): Destination neuron.
            path_length (int): Path-length cutoff passed to networkx.

        Returns:
            list: ``Path`` objects (or the cached groups).
        """
        base_name = f"Path_{self.name}_{target.name}_length_{path_length}"
        # Match "<base>_<j>" exactly; a bare prefix would let length_1 also
        # match groups created for length_10, length_11, ...
        prefix = f"{base_name}_"
        path_list = [
            self.network.groups[group]
            for group in self.network.groups
            if group.startswith(prefix)
        ]
        connection_paths = self._connection_paths(self, target, path_length)
        if len(path_list) == len(connection_paths):
            return path_list
        return [
            Path(self.network, path, f"{base_name}_{j}")
            for j, path in enumerate(connection_paths)
        ]

    def all_paths(self, path_length=1, direction="both"):
        """Return ``Path`` objects between this neuron and every network neuron.

        Args:
            path_length (int): Path-length cutoff passed to networkx.
            direction (str): ``"out"`` (paths starting here), ``"in"``
                (paths ending here) or ``"both"`` (outgoing paths first,
                then incoming).

        Returns:
            list: ``Path`` objects named
            ``Path_<self>_<out|in>_length_<n>_<j>_<k>``, where ``k``
            indexes the other neuron and ``j`` the path to or from it.

        Raises:
            ValueError: If ``direction`` is not ``"both"``, ``"in"`` or ``"out"``.
        """
        if direction not in ("out", "in", "both"):
            raise ValueError('Direction must be either "both", "in", or "out"')

        neurons = list(self.network.neurons.values())
        # Compute every connection list before creating any Path object.
        out_connection_paths = (
            [self._connection_paths(self, n, path_length) for n in neurons]
            if direction in ("out", "both")
            else []
        )
        in_connection_paths = (
            [self._connection_paths(n, self, path_length) for n in neurons]
            if direction in ("in", "both")
            else []
        )
        return self._build_paths(
            out_connection_paths, "out", path_length
        ) + self._build_paths(in_connection_paths, "in", path_length)

    def __str__(self):
        ## For use in debugging and testing
        return self.name
class NeuronGroup(Citable):
    """A named group of neurons belonging to one network.

    Members are held in two views: ``neurons`` (name -> Neuron, used for
    lookups and membership) and ``members`` (a list, used by the bulk
    property/connection helpers). The mutating methods keep ``members``
    equal to ``list(neurons.values())``.
    """

    def __init__(self, network, members=None, group_name=None) -> None:
        """
        Initializes a new instance of the NeuronGroup class.

        Parameters:
            network (NervousSystem): The owning nervous system. Used to
                resolve string members.
            members (Iterable[Union[Neuron, str]], optional): The members
                of the group. Each entry may be a ``Neuron`` instance or a
                neuron name (string); names are resolved via
                ``network.neurons``. Defaults to an empty group.
            group_name (str, optional): The name of the neuron group.
                Auto-generated if omitted.

        Raises:
            KeyError: If a string member is not a neuron of ``network``.
            TypeError: If a member is neither a Neuron nor a string.
            AssertionError: If ``group_name`` is already registered in
                ``network.groups``.
        """
        Citable.__init__(self)  # provides self.citations = {}
        if group_name is None:
            self.group_name = "Group-" + generate_random_string(8)
        else:
            self.group_name = group_name

        if members is None:
            members = []
        else:
            resolved = []
            for m in members:
                if isinstance(m, Neuron):
                    resolved.append(m)
                elif isinstance(m, str):
                    if m not in network.neurons:
                        raise KeyError(
                            f"Neuron {m!r} not found in network {network.name!r}"
                        )
                    resolved.append(network.neurons[m])
                else:
                    raise TypeError(
                        f"NeuronGroup members must be Neuron or str, got {type(m).__name__}"
                    )
            members = resolved

        self.members = members
        self.neurons = {m.name: m for m in members}
        self.network = network
        assert self.group_name not in self.network.groups, (
            f"Group name {self.group_name} already exists in the network"
        )
        self.network.groups.update({self.group_name: self})

    def _parent_citables(self):
        """Walk citations up to the containing NervousSystem."""
        return (self.network,) if self.network is not None else ()

    def _sync_members(self):
        """Rebuild the ``members`` list from the ``neurons`` dict."""
        self.members = list(self.neurons.values())

    # ---- Mapping-style access (keyed by neuron name) ----

    def __iter__(self):
        """Iterate over the names of the neurons in the group."""
        return iter(self.neurons)

    def items(self):
        """Yield ``(name, neuron)`` pairs for the group's members."""
        for key, value in self.neurons.items():
            yield key, value

    def keys(self):
        """Return a list of the member neuron names."""
        return list(self.neurons.keys())

    def values(self):
        """Return a list of the member neurons."""
        return list(self.neurons.values())

    def __len__(self):
        """Return the number of neurons in the group."""
        return len(self.neurons)

    def __contains__(self, neuron):
        """Return True if a neuron with the given *name* is in the group."""
        return neuron in self.neurons

    def __getitem__(self, neuron_name):
        """Return the member neuron called ``neuron_name`` (KeyError if absent)."""
        return self.neurons[neuron_name]

    def __setitem__(self, neuron_name, neuron):
        """Store ``neuron`` under ``neuron_name`` and refresh ``members``."""
        assert isinstance(neuron, Neuron), "Neuron group members must be of type Neuron"
        self.neurons[neuron_name] = neuron
        self._sync_members()

    def clear(self):
        """Remove all neurons from the group."""
        self.neurons = {}
        self.members = []

    def update(self, member_dict):
        """Merge a ``{name: Neuron}`` mapping into the group."""
        assert all(
            isinstance(neuron, Neuron) for neuron in member_dict.values()
        ), "Neuron group members must be of type Neuron"
        self.neurons.update(member_dict)
        self._sync_members()

    def pop(self, neuron_name):
        """Remove the neuron called ``neuron_name`` from the group and return it.

        Raises:
            KeyError: If no member has that name.
        """
        neuron = self.neurons.pop(neuron_name)
        self._sync_members()
        return neuron

    # ---- Bulk helpers over members ----

    def set_property(self, property_name, property_value):
        """Set ``property_name`` to ``property_value`` on every member neuron."""
        for neuron in self.members:
            neuron.set_property(property_name, property_value)

    def get_property(self, property_name):
        """Return ``property_name`` for every member neuron, in member order."""
        return [neuron.get_property(property_name) for neuron in self.members]

    def get_connections(self):
        """Return a list with each member's ``get_connections()`` dict."""
        return [neuron.get_connections() for neuron in self.members]

    def add_neuron(self, neuron: "Neuron") -> None:
        """Add a neuron to the group, keyed by its name.

        A different neuron already stored under the same name is replaced.

        Args:
            neuron: Neuron to add.
        """
        # Membership is keyed by name: testing the Neuron object against the
        # name-keyed dict would never match.
        if self.neurons.get(neuron.name) is not neuron:
            self.neurons[neuron.name] = neuron
            self._sync_members()

    def remove_neuron(self, neuron: "Neuron") -> None:
        """Remove the member with ``neuron``'s name, if present (no-op otherwise).

        Args:
            neuron: Neuron to remove.
        """
        if neuron.name in self.neurons:
            del self.neurons[neuron.name]
            self._sync_members()

    def get_neurons_by_type(self, type: str) -> List["Neuron"]:
        """Get all neurons of a specific type.

        Neurons without a ``type`` attribute are treated as having type None.

        Args:
            type: Neuron type to filter by.

        Returns:
            List of neurons of the specified type.
        """
        return [n for n in self.neurons.values() if getattr(n, "type", None) == type]

    def get_neurons_by_property(self, key: str, value: Any) -> List["Neuron"]:
        """Get neurons with a specific property value.

        Args:
            key: Property name.
            value: Property value to match.

        Returns:
            List of neurons with matching property value.
        """
        return [n for n in self.neurons.values() if n.get_property(key) == value]

    # ---- Set operations ----

    def union(self, other: "NeuronGroup", group_name: str = None) -> "NeuronGroup":
        """Return a new NeuronGroup containing neurons from both groups.

        Members of ``self`` come first, followed by those only in ``other``.
        The new group is registered in the network, so its name must be unused.

        Args:
            other: Another NeuronGroup to union with.
            group_name: Optional name for the resulting group.

        Returns:
            A new NeuronGroup with the combined members.
        """
        assert isinstance(other, NeuronGroup), "Operand must be a NeuronGroup"
        assert (
            self.network is other.network
        ), "Both groups must belong to the same network"
        merged = {**self.neurons, **other.neurons}
        name = group_name or f"{self.group_name}_union_{other.group_name}"
        return NeuronGroup(self.network, members=list(merged.values()), group_name=name)

    def intersection(
        self, other: "NeuronGroup", group_name: str = None
    ) -> "NeuronGroup":
        """Return a new NeuronGroup containing only neurons present in both groups.

        Members keep ``self``'s order, so the result is deterministic.

        Args:
            other: Another NeuronGroup to intersect with.
            group_name: Optional name for the resulting group.

        Returns:
            A new NeuronGroup with the shared members.
        """
        assert isinstance(other, NeuronGroup), "Operand must be a NeuronGroup"
        assert (
            self.network is other.network
        ), "Both groups must belong to the same network"
        members = [n for key, n in self.neurons.items() if key in other.neurons]
        name = group_name or f"{self.group_name}_intersect_{other.group_name}"
        return NeuronGroup(self.network, members=members, group_name=name)

    def difference(self, other: "NeuronGroup", group_name: str = None) -> "NeuronGroup":
        """Return a new NeuronGroup containing neurons in self but not in other.

        Members keep ``self``'s order, so the result is deterministic.

        Args:
            other: Another NeuronGroup to subtract.
            group_name: Optional name for the resulting group.

        Returns:
            A new NeuronGroup with the difference members.
        """
        assert isinstance(other, NeuronGroup), "Operand must be a NeuronGroup"
        assert (
            self.network is other.network
        ), "Both groups must belong to the same network"
        members = [n for key, n in self.neurons.items() if key not in other.neurons]
        name = group_name or f"{self.group_name}_diff_{other.group_name}"
        return NeuronGroup(self.network, members=members, group_name=name)

    def __or__(self, other: "NeuronGroup") -> "NeuronGroup":
        """Operator ``|`` — union."""
        return self.union(other)

    def __and__(self, other: "NeuronGroup") -> "NeuronGroup":
        """Operator ``&`` — intersection."""
        return self.intersection(other)

    def __sub__(self, other: "NeuronGroup") -> "NeuronGroup":
        """Operator ``-`` — difference."""
        return self.difference(other)