Source code for cedne.core.recordings

"""
Trial-wise neural activity and stimulus-response analysis for CeDNe.

This module defines classes for capturing and analyzing experimental recordings
from neurons. It includes:

- `Trial`: Represents a single experimental recording for a neuron, such as
  a calcium imaging time series. Trials are stored per-neuron in the `Neuron.trial`
  dictionary and support signal preprocessing (e.g., bleaching correction).

- `StimResponse`: Encapsulates a stimulus-response pair recorded during a `Trial`,
  and extracts a set of interpretable features from the response signal, including
  max amplitude, onset time, area under the curve, and others.

These classes are designed to support time-locked calcium imaging experiments
and help link dynamic neural activity to behavioral or stimulus-driven contexts.
"""

__author__ = "Sahil Moza"
__date__ = "2025-04-06"
__license__ = "MIT"

from typing import Optional, List, Union, Dict, Any, TYPE_CHECKING
from datetime import datetime
import numpy as np
from scipy import signal
import scipy.stats as ss
from scipy.ndimage import gaussian_filter1d
from numpy.typing import NDArray
from .config import F_SAMPLE

if TYPE_CHECKING:
    from .neuron import Neuron


[docs] class Trial: """A class representing a single experimental recording for a neuron. This class handles the storage and basic processing of time series data recorded from a neuron during an experimental trial. It supports operations like data storage, bleaching correction, and basic signal processing. Attributes: neuron: The neuron object associated with this trial. i (int): The trial number. discard (List[int]): Points to be discarded due to bleaching or artifacts. _data (NDArray): The actual recording data. metadata (Dict): Dictionary containing trial metadata. """ def __init__(self, neuron: "Neuron", trialnum: int) -> None: """Initialize a new Trial instance. Args: neuron: The neuron object associated with this trial. trialnum: The trial number identifier. """ self.neuron = neuron self.i = trialnum self.discard: List[int] = [] self._data: Optional[NDArray] = None self.behavior = None # Optional back-reference to Behavior object self.metadata: Dict[str, Any] = { "trial_number": trialnum, "neuron_id": id(neuron), "sampling_rate": F_SAMPLE, "processing_history": [], } @property def recording(self) -> NDArray: """Get the recording data for the trial. Returns: NDArray: The recording time series data. Raises: ValueError: If no recording data has been set. """ if self._data is None: raise ValueError("No recording data has been set") return self._data @recording.setter def recording(self, data: NDArray, discard: float = 0) -> None: """Set the recording data for the trial. Args: data: The time series data to be recorded. discard: Number of initial seconds to discard (e.g., for bleaching correction). Raises: ValueError: If discard is negative or if data is invalid. """ if not isinstance(data, np.ndarray): data = np.array(data) if data.ndim != 1: raise ValueError("Recording data must be a 1D array") if discard < 0: raise ValueError("Discard cannot be negative") if discard > 0: discard_points = int(discard * F_SAMPLE) if discard_points >= len(data): raise ValueError("Discard duration exceeds data length") self.discard = list(range(discard_points)) self._data = data[discard_points:].astype(np.float64) else: self.discard = [] self._data = data.astype(np.float64)
[docs] def get_duration(self) -> float: """Get the duration of the recording in seconds. Returns: float: Duration of the recording in seconds. Raises: ValueError: If no recording data has been set. """ return len(self.recording) / F_SAMPLE
[docs] def get_timestamps(self) -> NDArray: """Get the timestamps for each sample in the recording. Returns: NDArray: Array of timestamps in seconds. Raises: ValueError: If no recording data has been set. """ return np.arange(len(self.recording)) / F_SAMPLE
[docs] def filter_signal( self, filter_type: str = "lowpass", cutoff_freq: float = 10.0, order: int = 4 ) -> NDArray: """Apply a Butterworth filter to the recording data. Args: filter_type: Type of filter ('lowpass', 'highpass', or 'bandpass'). cutoff_freq: Cutoff frequency in Hz. For bandpass, provide tuple (low, high). order: Order of the Butterworth filter. Returns: NDArray: Filtered signal. Raises: ValueError: If filter_type is invalid or if no recording data has been set. """ nyquist = F_SAMPLE / 2 if isinstance(cutoff_freq, (list, tuple)): cutoff_freq = np.array(cutoff_freq) if np.any(cutoff_freq <= 0) or np.any(cutoff_freq >= nyquist): raise ValueError( "Cutoff frequencies must be between 0 and nyquist frequency" ) else: if cutoff_freq <= 0 or cutoff_freq >= nyquist: raise ValueError( "Cutoff frequency must be between 0 and nyquist frequency" ) normalized_cutoff = cutoff_freq / nyquist if filter_type == "lowpass": b, a = signal.butter(order, normalized_cutoff, btype="low") elif filter_type == "highpass": b, a = signal.butter(order, normalized_cutoff, btype="high") elif filter_type == "bandpass": if ( not isinstance(normalized_cutoff, (list, tuple, np.ndarray)) or len(normalized_cutoff) != 2 ): raise ValueError("Bandpass filter requires two cutoff frequencies") b, a = signal.butter(order, normalized_cutoff, btype="band") else: raise ValueError(f"Invalid filter type: {filter_type}") return signal.filtfilt(b, a, self.recording)
[docs] def smooth_signal(self, window_size: int = 5, method: str = "moving") -> NDArray: """Smooth the recording signal using various methods. Args: window_size: Size of the smoothing window in samples. method: Smoothing method ('moving', 'gaussian', or 'median'). Returns: NDArray: Smoothed signal. Raises: ValueError: If method is invalid or if no recording data has been set. """ if method == "moving": window = np.ones(window_size) / window_size return np.convolve(self.recording, window, mode="same") elif method == "gaussian": return gaussian_filter1d(self.recording, window_size) elif method == "median": return signal.medfilt(self.recording, window_size) else: raise ValueError(f"Invalid smoothing method: {method}")
[docs] def normalize_signal( self, method: str = "minmax", baseline_window: Optional[tuple] = None ) -> NDArray: """Normalize the recording signal using various methods. Args: method: Normalization method ('minmax', 'zscore', or 'baseline'). baseline_window: Tuple of (start, end) indices for baseline normalization. Returns: NDArray: Normalized signal. Raises: ValueError: If method is invalid or if no recording data has been set. """ if method == "minmax": min_val = np.min(self.recording) max_val = np.max(self.recording) return (self.recording - min_val) / (max_val - min_val) elif method == "zscore": return (self.recording - np.mean(self.recording)) / np.std(self.recording) elif method == "baseline": if baseline_window is None: raise ValueError( "baseline_window must be provided for baseline normalization" ) start, end = baseline_window baseline = np.mean(self.recording[start:end]) return (self.recording - baseline) / baseline else: raise ValueError(f"Invalid normalization method: {method}")
[docs] def detect_peaks( self, height: Optional[float] = None, distance: Optional[int] = None ) -> tuple[NDArray, NDArray]: """Detect peaks in the recording signal. Args: height: Minimum height of peaks. distance: Minimum distance between peaks in samples. Returns: tuple: (peak_indices, peak_heights) Raises: ValueError: If no recording data has been set. """ if distance is not None and distance < 1: raise ValueError("`distance` must be greater or equal to 1") peaks, properties = signal.find_peaks( self.recording, height=height, distance=distance ) peak_heights = ( self.recording[peaks] if "peak_heights" not in properties else properties["peak_heights"] ) return peaks, peak_heights
[docs] def get_statistics(self) -> dict: """Calculate basic statistics of the recording. Returns: dict: Dictionary containing various statistical measures: - mean: Mean of the signal - std: Standard deviation - median: Median value - min: Minimum value - max: Maximum value - skewness: Skewness of the distribution - kurtosis: Kurtosis of the distribution - rms: Root mean square value Raises: ValueError: If no recording data has been set. """ return { "mean": np.mean(self.recording), "std": np.std(self.recording), "median": np.median(self.recording), "min": np.min(self.recording), "max": np.max(self.recording), "skewness": ss.skew(self.recording), "kurtosis": ss.kurtosis(self.recording), "rms": np.sqrt(np.mean(np.square(self.recording))), }
[docs] def compute_power_spectrum(self, window: str = "hann") -> tuple[NDArray, NDArray]: """Compute the power spectrum of the recording. Args: window: Window function to use ('hann', 'hamming', 'blackman', etc.). Returns: tuple: Arrays of frequencies and corresponding power spectrum. Raises: ValueError: If no recording data has been set. """ freqs, psd = signal.welch(self.recording, F_SAMPLE, window=window) return freqs, psd
[docs] def compute_snr(self, signal_window: tuple, noise_window: tuple) -> float: """Compute the signal-to-noise ratio. Args: signal_window: Tuple of (start, end) indices for signal region. noise_window: Tuple of (start, end) indices for noise region. Returns: float: Signal-to-noise ratio in dB. Raises: ValueError: If windows are invalid or if no recording data has been set. """ sig_start, sig_end = signal_window noise_start, noise_end = noise_window signal_power = np.mean(np.square(self.recording[sig_start:sig_end])) noise_power = np.mean(np.square(self.recording[noise_start:noise_end])) if noise_power == 0: return float("inf") return 10 * np.log10(signal_power / noise_power)
[docs] def segment_signal( self, threshold: float, min_duration: int = 10 ) -> List[tuple[int, int]]: """Segment the signal into regions above threshold. Args: threshold: Amplitude threshold for segmentation. min_duration: Minimum duration (in samples) for a valid segment. Returns: List[tuple]: List of (start, end) indices for each segment. Raises: ValueError: If no recording data has been set. """ above_threshold = self.recording > threshold changes = np.diff(above_threshold.astype(int)) rise_points = np.where(changes == 1)[0] + 1 fall_points = np.where(changes == -1)[0] + 1 # Handle edge cases if len(rise_points) == 0 or len(fall_points) == 0: return [] if rise_points[0] > fall_points[0]: rise_points = np.insert(rise_points, 0, 0) if rise_points[-1] > fall_points[-1]: fall_points = np.append(fall_points, len(self.recording)) segments = [] for start, end in zip(rise_points, fall_points): if end - start >= min_duration: segments.append((start, end)) return segments
[docs] def add_metadata(self, key: str, value: Any) -> None: """Add metadata to the trial. Args: key: Metadata key. value: Metadata value. """ self.metadata[key] = value
[docs] def get_metadata(self, key: str) -> Any: """Get metadata value. Args: key: Metadata key. Returns: The metadata value. Raises: KeyError: If the key doesn't exist in metadata. """ return self.metadata[key]
[docs] def log_processing(self, operation: str, parameters: Dict[str, Any]) -> None: """Log a processing operation in the trial's history. Args: operation: Name of the processing operation. parameters: Dictionary of parameters used in the operation. """ self.metadata["processing_history"].append( { "timestamp": datetime.now().isoformat(), "operation": operation, "parameters": parameters, } )
[docs] def validate_data(self) -> bool: """Validate the recording data. Returns: bool: True if data is valid, False otherwise. This method checks: - Data is not None - Data is a numpy array - Data is 1-dimensional - Data contains no NaN or infinite values - Data length is reasonable (> 0) """ if self._data is None: return False if not isinstance(self._data, np.ndarray): return False if self._data.ndim != 1: return False if not np.isfinite(self._data).all(): return False if len(self._data) == 0: return False return True
[docs] def get_quality_metrics(self) -> Dict[str, float]: """Calculate quality metrics for the recording. Returns: dict: Dictionary containing quality metrics: - snr: Signal-to-noise ratio - noise_level: Estimated noise level - signal_stability: Measure of signal stability - artifact_count: Number of potential artifacts Raises: ValueError: If no recording data has been set. """ if not self.validate_data(): raise ValueError("Invalid or missing recording data") # Calculate noise level from the first 10% of the signal noise_window = slice(0, len(self._data) // 10) noise_level = np.std(self._data[noise_window]) # Detect potential artifacts (points > 3 std from mean) mean = np.mean(self._data) std = np.std(self._data) artifacts = np.sum(np.abs(self._data - mean) > 3 * std) # Calculate signal stability (variation in signal segments) segment_length = len(self._data) // 10 segments = np.array_split(self._data, 10) segment_means = [np.mean(seg) for seg in segments] stability = 1 - np.std(segment_means) / np.mean(segment_means) return { "noise_level": float(noise_level), "artifact_count": int(artifacts), "signal_stability": float(stability), }
[docs] class StimResponse: """A class representing a stimulus-response pair in a neural recording. This class handles the analysis of neural responses to specific stimuli, extracting various features from the response signal and providing methods for response characterization. Attributes: stim (NDArray): The stimulus signal. response (NDArray): The response signal. feature (Dict[int, Any]): Dictionary of extracted features. neuron: The neuron object associated with this response. f_sample (float): Sampling frequency in Hz. sampling_time (float): Time between samples in seconds. baseline (NDArray): Baseline signal before stimulus. Features extracted include: 0: Maximum response amplitude 1: Area under the curve 2: Mean response 3: Time to peak 4: Area under the curve to peak 5: Minimum response 6: Response onset time 7: Positive response area 8: Absolute area under the curve """ def __init__( self, trial: Trial, stimulus: NDArray, response: NDArray, baseline_samples: int ) -> None: """Initialize a StimResponse instance. Args: trial: The trial object associated with this response. stimulus: The stimulus signal. response: The response signal. baseline_samples: Number of samples to use for baseline calculation. Raises: ValueError: If input arrays have invalid dimensions or lengths. """ if not isinstance(stimulus, np.ndarray) or not isinstance(response, np.ndarray): raise ValueError("Stimulus and response must be numpy arrays") if stimulus.ndim != 1 or response.ndim != 1: raise ValueError("Stimulus and response must be 1-dimensional") if len(stimulus) != len(response): raise ValueError("Stimulus and response must have the same length") if baseline_samples >= len(response): raise ValueError("Baseline samples exceeds response length") self.stim = stimulus self.response = response self.feature: Dict[int, Any] = {} self.neuron = trial.neuron self.f_sample = F_SAMPLE self.sampling_time = 1.0 / self.f_sample self.baseline = self.response[:baseline_samples] # Extract all features for feature_index in range(9): # 9 features total self.feature[feature_index] = self.extract_feature(feature_index)
[docs] def extract_feature(self, feature_index: int) -> Union[float, tuple[float, float]]: """Extract a specific feature from the stimulus-response pair. Args: feature_index: Index of the feature to extract: 0: Maximum value 1: Area under the curve 2: Time to peak 3: Mean value 4: Area under the curve to peak 5: Minimum value 6: Onset time 7: Positive area 8: Absolute area under the curve Returns: The extracted feature value or tuple of values. Raises: ValueError: If feature_index is invalid. """ feature_mapping = { 0: self._find_maximum, 1: self._area_under_the_curve, 2: self._find_time_to_peak, 3: self._find_mean, 4: self._area_under_the_curve_to_peak, 5: self._find_minimum, 6: self._find_onset_time, 7: self._find_positive_area, 8: self._absolute_area_under_the_curve, } if feature_index not in feature_mapping: raise ValueError(f"Invalid feature index: {feature_index}") return feature_mapping[feature_index]()
def _find_maximum(self) -> float: """Find the maximum response amplitude. Returns: float: Maximum value of the response signal. """ return float(np.max(self.response)) def _find_minimum(self) -> float: """Find the minimum response amplitude. Returns: float: Minimum value of the response signal. """ return float(np.min(self.response)) def _find_time_to_peak(self) -> float: """Find the time to response peak. Returns: float: Time to peak in seconds. """ max_index = np.argmax(self.response) return float(max_index * self.sampling_time) def _find_mean(self) -> float: """Calculate the mean response amplitude. Returns: float: Mean value of the response signal. """ return float(np.mean(self.response)) def _area_under_the_curve(self, bin_size: int = 5) -> float: """Calculate the total area under the response curve. Args: bin_size: Number of samples to bin for integration. Returns: float: Area under the curve in amplitude-seconds. """ undersampling = self.response[::bin_size] return float(np.trapz(undersampling, dx=self.sampling_time * bin_size)) def _absolute_area_under_the_curve(self, bin_size: int = 5) -> float: """Calculate the absolute area under the response curve. Args: bin_size: Number of samples to bin for integration. Returns: float: Absolute area under the curve in amplitude-seconds. """ undersampling = np.abs(self.response[::bin_size]) return float(np.trapz(undersampling, dx=self.sampling_time * bin_size)) def _area_under_the_curve_to_peak(self, bin_size: int = 10) -> float: """Calculate the area under the curve up to the peak response. Args: bin_size: Number of samples to bin for integration. Returns: float: Area under the curve to peak in amplitude-seconds. """ undersampling = self.response[::bin_size] max_index = np.argmax(undersampling) window_to_peak = undersampling[: max_index + 1] return float(np.trapz(window_to_peak, dx=self.sampling_time * bin_size)) def _find_onset_time( self, window_size: int = 10, threshold_std: float = 2.0, absolute_threshold: Optional[float] = None, ) -> float: """Find the response onset time using a sliding window approach. For calcium imaging data, this method uses a robust detection approach that: 1. Requires the signal to stay above threshold for the entire window 2. Checks that the signal continues to rise after the window 3. Ensures the detected onset is not just a noise spike 4. Requires a moderate increase in signal level Args: window_size: Size of the sliding window in samples. threshold_std: Number of standard deviations above baseline for onset. Only used if absolute_threshold is None. absolute_threshold: Absolute threshold value for onset detection. If provided, overrides threshold_std calculation. Returns: float: Onset time in seconds. """ baseline_mean = np.mean(self.baseline) baseline_std = np.std(self.baseline) # Calculate threshold based on provided method if absolute_threshold is not None: threshold = absolute_threshold else: # Use statistical threshold threshold = baseline_mean + threshold_std * baseline_std # Use a sliding window to find when response consistently exceeds threshold for i in range(len(self.response) - window_size): window = self.response[i : i + window_size] if np.mean(window) > threshold: return float(i * self.sampling_time) # If no onset found, return nan return np.nan def _find_positive_area(self, bin_size: int = 5) -> tuple[float, float]: """Calculate the positive and negative areas of the response. Args: bin_size: Number of samples to bin for integration. Returns: tuple: (positive_area, negative_area) in amplitude-seconds. """ undersampling = self.response[::bin_size] positive_trace = np.clip(undersampling, 0, None) negative_trace = np.clip(-undersampling, 0, None) dx = self.sampling_time * bin_size pos_area = float(np.trapz(positive_trace, dx=dx)) neg_area = float(np.trapz(negative_trace, dx=dx)) abs_area = float(np.trapz(np.abs(undersampling), dx=dx)) signed_total = pos_area + neg_area if signed_total > abs_area and np.isclose(signed_total, abs_area): scale = abs_area / signed_total pos_area *= scale neg_area *= scale return pos_area, neg_area
[docs] def get_response_characteristics(self) -> Dict[str, float]: """Calculate comprehensive response characteristics. Returns: dict: Dictionary containing various response metrics: - amplitude: Peak response amplitude (relative to baseline) - duration: Response duration - latency: Response latency - integral: Total response integral - baseline_mean: Mean baseline activity - baseline_std: Baseline standard deviation - signal_to_noise: Signal-to-noise ratio """ baseline_mean = np.mean(self.baseline) baseline_std = np.std(self.baseline) # Calculate peak amplitude relative to baseline # First subtract baseline from the entire response response_minus_baseline = self.response - baseline_mean # Find the maximum absolute deviation max_abs_idx = np.argmax(np.abs(response_minus_baseline)) # Use the actual value at that index peak_amplitude = response_minus_baseline[max_abs_idx] onset_time = self._find_onset_time() # Find response end (when signal returns to baseline) end_threshold = baseline_mean + 2 * baseline_std response_end = np.where(self.response <= end_threshold)[0] response_end = response_end[-1] if len(response_end) > 0 else len(self.response) duration = ( (response_end - onset_time) * self.sampling_time if not np.isnan(onset_time) else 0.0 ) integral = self._area_under_the_curve() # Calculate signal-to-noise ratio signal_power = np.mean(np.square(self.response - baseline_mean)) noise_power = np.mean(np.square(self.baseline - baseline_mean)) snr = ( 10 * np.log10(signal_power / noise_power) if noise_power > 0 else float("inf") ) return { "amplitude": float(peak_amplitude), "duration": float(duration), "latency": float(onset_time), "integral": float(integral), "baseline_mean": float(baseline_mean), "baseline_std": float(baseline_std), "signal_to_noise": float(snr), }
def _linear_transform(value, minvalue, maxvalue): return (value - minvalue) / (maxvalue - minvalue)