"""
Trial-wise neural activity and stimulus-response analysis for CeDNe.
This module defines classes for capturing and analyzing experimental recordings
from neurons. It includes:
- `Trial`: Represents a single experimental recording for a neuron, such as
a calcium imaging time series. Trials are stored per-neuron in the `Neuron.trial`
dictionary and support signal preprocessing (e.g., bleaching correction).
- `StimResponse`: Encapsulates a stimulus-response pair recorded during a `Trial`,
and extracts a set of interpretable features from the response signal, including
max amplitude, onset time, area under the curve, and others.
These classes are designed to support time-locked calcium imaging experiments
and help link dynamic neural activity to behavioral or stimulus-driven contexts.
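
Example (illustrative sketch; assumes ``neuron`` is a Neuron instance and
``trace``, ``stim``, and ``resp`` are 1-D numpy arrays sampled at F_SAMPLE):

    >>> trial = Trial(neuron, trialnum=0)
    >>> trial.recording = trace
    >>> stats = trial.get_statistics()
    >>> sr = StimResponse(trial, stim, resp, baseline_samples=50)
    >>> sr.feature[0]  # maximum response amplitude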
"""
__author__ = "Sahil Moza"
__date__ = "2025-04-06"
__license__ = "MIT"
from typing import Optional, List, Union, Dict, Any, TYPE_CHECKING
from datetime import datetime
import numpy as np
from scipy import signal
import scipy.stats as ss
from scipy.ndimage import gaussian_filter1d
from numpy.typing import NDArray
from .config import F_SAMPLE
if TYPE_CHECKING:
from .neuron import Neuron
class Trial:
"""A class representing a single experimental recording for a neuron.
This class handles the storage and basic processing of time series data recorded
from a neuron during an experimental trial. It supports operations like data
storage, bleaching correction, and basic signal processing.
Attributes:
neuron: The neuron object associated with this trial.
i (int): The trial number.
        discard (List[int]): Indices of initial samples discarded due to bleaching or artifacts.
_data (NDArray): The actual recording data.
metadata (Dict): Dictionary containing trial metadata.
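
    Example (a minimal sketch; ``neuron`` is assumed to be a Neuron instance):
        >>> trial = Trial(neuron, trialnum=1)
        >>> trial.recording = np.zeros(100)
        >>> trial.get_duration() == 100 / F_SAMPLE
        True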
"""
def __init__(self, neuron: 'Neuron', trialnum: int) -> None:
"""Initialize a new Trial instance.
Args:
neuron: The neuron object associated with this trial.
trialnum: The trial number identifier.
"""
self.neuron = neuron
self.i = trialnum
self.discard: List[int] = []
self._data: Optional[NDArray] = None
self.metadata: Dict[str, Any] = {
'trial_number': trialnum,
'neuron_id': id(neuron),
'sampling_rate': F_SAMPLE,
'processing_history': []
}
@property
def recording(self) -> NDArray:
"""Get the recording data for the trial.
Returns:
NDArray: The recording time series data.
Raises:
ValueError: If no recording data has been set.
"""
if self._data is None:
raise ValueError("No recording data has been set")
return self._data
    @recording.setter
    def recording(self, data: NDArray) -> None:
        """Set the recording data for the trial.

        Note: a property setter accepts a single value, so initial samples
        cannot be discarded through assignment; use :meth:`set_recording`
        to discard an initial bleaching period.

        Args:
            data: The time series data to be recorded.

        Raises:
            ValueError: If data is invalid.
        """
        self.set_recording(data)

    def set_recording(self, data: NDArray, discard: float = 0) -> None:
        """Set the recording data, optionally discarding initial samples.

        Args:
            data: The time series data to be recorded.
            discard: Number of initial seconds to discard (e.g., for bleaching
                correction).

        Raises:
            ValueError: If discard is negative or if data is invalid.
        """
        if not isinstance(data, np.ndarray):
            data = np.array(data)
        if data.ndim != 1:
            raise ValueError("Recording data must be a 1D array")
        if discard < 0:
            raise ValueError("Discard cannot be negative")
        if discard > 0:
            discard_points = int(discard * F_SAMPLE)
            if discard_points >= len(data):
                raise ValueError("Discard duration exceeds data length")
            self.discard = list(range(discard_points))
            self._data = data[discard_points:].astype(np.float64)
        else:
            self.discard = []
            self._data = data.astype(np.float64)
def get_duration(self) -> float:
"""Get the duration of the recording in seconds.
Returns:
float: Duration of the recording in seconds.
Raises:
ValueError: If no recording data has been set.
"""
return len(self.recording) / F_SAMPLE
def get_timestamps(self) -> NDArray:
"""Get the timestamps for each sample in the recording.
Returns:
NDArray: Array of timestamps in seconds.
Raises:
ValueError: If no recording data has been set.
"""
return np.arange(len(self.recording)) / F_SAMPLE
    def filter_signal(self, filter_type: str = 'lowpass',
                      cutoff_freq: Union[float, tuple] = 10.0,
                      order: int = 4) -> NDArray:
"""Apply a Butterworth filter to the recording data.
Args:
filter_type: Type of filter ('lowpass', 'highpass', or 'bandpass').
cutoff_freq: Cutoff frequency in Hz. For bandpass, provide tuple (low, high).
order: Order of the Butterworth filter.
Returns:
NDArray: Filtered signal.
Raises:
ValueError: If filter_type is invalid or if no recording data has been set.
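
        Example (illustrative; assumes recording data has been set):
            >>> low = trial.filter_signal('lowpass', cutoff_freq=5.0)
            >>> band = trial.filter_signal('bandpass', cutoff_freq=(0.5, 5.0))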
"""
nyquist = F_SAMPLE / 2
if isinstance(cutoff_freq, (list, tuple)):
cutoff_freq = np.array(cutoff_freq)
            if np.any(cutoff_freq <= 0) or np.any(cutoff_freq >= nyquist):
                raise ValueError("Cutoff frequencies must be between 0 and the Nyquist frequency")
        else:
            if cutoff_freq <= 0 or cutoff_freq >= nyquist:
                raise ValueError("Cutoff frequency must be between 0 and the Nyquist frequency")
normalized_cutoff = cutoff_freq / nyquist
if filter_type == 'lowpass':
b, a = signal.butter(order, normalized_cutoff, btype='low')
elif filter_type == 'highpass':
b, a = signal.butter(order, normalized_cutoff, btype='high')
elif filter_type == 'bandpass':
if not isinstance(normalized_cutoff, (list, tuple, np.ndarray)) or len(normalized_cutoff) != 2:
raise ValueError("Bandpass filter requires two cutoff frequencies")
b, a = signal.butter(order, normalized_cutoff, btype='band')
else:
raise ValueError(f"Invalid filter type: {filter_type}")
return signal.filtfilt(b, a, self.recording)
def smooth_signal(self, window_size: int = 5, method: str = 'moving') -> NDArray:
"""Smooth the recording signal using various methods.
Args:
window_size: Size of the smoothing window in samples.
method: Smoothing method ('moving', 'gaussian', or 'median').
Returns:
NDArray: Smoothed signal.
Raises:
ValueError: If method is invalid or if no recording data has been set.
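
        Example (illustrative; assumes recording data has been set):
            >>> smoothed = trial.smooth_signal(window_size=7, method='median')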
"""
        if window_size < 1:
            raise ValueError("window_size must be a positive integer")
        if method == 'moving':
            window = np.ones(window_size) / window_size
            return np.convolve(self.recording, window, mode='same')
        elif method == 'gaussian':
            # window_size is used as the standard deviation (sigma) of the kernel
            return gaussian_filter1d(self.recording, window_size)
        elif method == 'median':
            # scipy.signal.medfilt requires an odd kernel size
            if window_size % 2 == 0:
                raise ValueError("window_size must be odd for median filtering")
            return signal.medfilt(self.recording, window_size)
else:
raise ValueError(f"Invalid smoothing method: {method}")
def normalize_signal(self, method: str = 'minmax', baseline_window: Optional[tuple] = None) -> NDArray:
"""Normalize the recording signal using various methods.
Args:
method: Normalization method ('minmax', 'zscore', or 'baseline').
baseline_window: Tuple of (start, end) indices for baseline normalization.
Returns:
NDArray: Normalized signal.
Raises:
ValueError: If method is invalid or if no recording data has been set.
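
        Example (illustrative; uses the first 100 samples as baseline):
            >>> df_f = trial.normalize_signal('baseline', baseline_window=(0, 100))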
"""
        if method == 'minmax':
            min_val = np.min(self.recording)
            max_val = np.max(self.recording)
            if max_val == min_val:
                raise ValueError("Cannot min-max normalize a constant signal")
            return (self.recording - min_val) / (max_val - min_val)
        elif method == 'zscore':
            return (self.recording - np.mean(self.recording)) / np.std(self.recording)
        elif method == 'baseline':
            if baseline_window is None:
                raise ValueError("baseline_window must be provided for baseline normalization")
            start, end = baseline_window
            baseline = np.mean(self.recording[start:end])
            if baseline == 0:
                raise ValueError("Baseline mean is zero; cannot compute relative change")
            return (self.recording - baseline) / baseline
else:
raise ValueError(f"Invalid normalization method: {method}")
def detect_peaks(self, height: Optional[float] = None,
distance: Optional[int] = None) -> tuple[NDArray, NDArray]:
"""Detect peaks in the recording signal.
Args:
height: Minimum height of peaks.
distance: Minimum distance between peaks in samples.
Returns:
tuple: (peak_indices, peak_heights)
Raises:
ValueError: If no recording data has been set.
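
        Example (illustrative; assumes recording data has been set):
            >>> peaks, heights = trial.detect_peaks(height=0.5, distance=10)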
"""
if distance is not None and distance < 1:
raise ValueError("`distance` must be greater or equal to 1")
peaks, properties = signal.find_peaks(self.recording, height=height, distance=distance)
peak_heights = self.recording[peaks] if 'peak_heights' not in properties else properties['peak_heights']
return peaks, peak_heights
def get_statistics(self) -> dict:
"""Calculate basic statistics of the recording.
Returns:
dict: Dictionary containing various statistical measures:
- mean: Mean of the signal
- std: Standard deviation
- median: Median value
- min: Minimum value
- max: Maximum value
- skewness: Skewness of the distribution
- kurtosis: Kurtosis of the distribution
- rms: Root mean square value
Raises:
ValueError: If no recording data has been set.
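
        Example (illustrative; assumes recording data has been set):
            >>> stats = trial.get_statistics()
            >>> stats['mean'], stats['rms']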
"""
return {
'mean': np.mean(self.recording),
'std': np.std(self.recording),
'median': np.median(self.recording),
'min': np.min(self.recording),
'max': np.max(self.recording),
'skewness': ss.skew(self.recording),
'kurtosis': ss.kurtosis(self.recording),
'rms': np.sqrt(np.mean(np.square(self.recording)))
}
def compute_power_spectrum(self, window: str = 'hann') -> tuple[NDArray, NDArray]:
"""Compute the power spectrum of the recording.
Args:
window: Window function to use ('hann', 'hamming', 'blackman', etc.).
Returns:
tuple: Arrays of frequencies and corresponding power spectrum.
Raises:
ValueError: If no recording data has been set.
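
        Example (illustrative; assumes recording data has been set):
            >>> freqs, psd = trial.compute_power_spectrum(window='hann')
            >>> dominant = freqs[np.argmax(psd)]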
"""
freqs, psd = signal.welch(self.recording, F_SAMPLE, window=window)
return freqs, psd
def compute_snr(self, signal_window: tuple, noise_window: tuple) -> float:
"""Compute the signal-to-noise ratio.
Args:
signal_window: Tuple of (start, end) indices for signal region.
noise_window: Tuple of (start, end) indices for noise region.
Returns:
float: Signal-to-noise ratio in dB.
Raises:
ValueError: If windows are invalid or if no recording data has been set.
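
        Example (illustrative; noise estimated from the first 100 samples):
            >>> snr_db = trial.compute_snr(signal_window=(100, 500), noise_window=(0, 100))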
"""
        sig_start, sig_end = signal_window
        noise_start, noise_end = noise_window
        n = len(self.recording)
        if not (0 <= sig_start < sig_end <= n and 0 <= noise_start < noise_end <= n):
            raise ValueError("Windows must satisfy 0 <= start < end <= len(recording)")
        signal_power = np.mean(np.square(self.recording[sig_start:sig_end]))
        noise_power = np.mean(np.square(self.recording[noise_start:noise_end]))
if noise_power == 0:
return float('inf')
return 10 * np.log10(signal_power / noise_power)
def segment_signal(self, threshold: float, min_duration: int = 10) -> List[tuple[int, int]]:
"""Segment the signal into regions above threshold.
Args:
threshold: Amplitude threshold for segmentation.
min_duration: Minimum duration (in samples) for a valid segment.
Returns:
List[tuple]: List of (start, end) indices for each segment.
Raises:
ValueError: If no recording data has been set.
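
        Example (illustrative; assumes recording data has been set):
            >>> segments = trial.segment_signal(threshold=0.2, min_duration=20)
            >>> [(end - start) / F_SAMPLE for start, end in segments]  # durations in seconds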
"""
above_threshold = self.recording > threshold
changes = np.diff(above_threshold.astype(int))
rise_points = np.where(changes == 1)[0] + 1
fall_points = np.where(changes == -1)[0] + 1
        if len(above_threshold) == 0:
            return []
        # Handle edge cases: the signal may already be above threshold at the
        # start of the recording, or still above threshold at the end
        if above_threshold[0]:
            rise_points = np.insert(rise_points, 0, 0)
        if above_threshold[-1]:
            fall_points = np.append(fall_points, len(self.recording))
segments = []
for start, end in zip(rise_points, fall_points):
if end - start >= min_duration:
segments.append((start, end))
return segments
def log_processing(self, operation: str, parameters: Dict[str, Any]) -> None:
"""Log a processing operation in the trial's history.
Args:
operation: Name of the processing operation.
parameters: Dictionary of parameters used in the operation.
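
        Example (illustrative):
            >>> trial.log_processing('filter_signal',
            ...                      {'filter_type': 'lowpass', 'cutoff_freq': 5.0})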
"""
self.metadata['processing_history'].append({
'timestamp': datetime.now().isoformat(),
'operation': operation,
'parameters': parameters
})
def validate_data(self) -> bool:
"""Validate the recording data.
Returns:
bool: True if data is valid, False otherwise.
This method checks:
- Data is not None
- Data is a numpy array
- Data is 1-dimensional
- Data contains no NaN or infinite values
- Data length is reasonable (> 0)
"""
if self._data is None:
return False
if not isinstance(self._data, np.ndarray):
return False
if self._data.ndim != 1:
return False
if not np.isfinite(self._data).all():
return False
if len(self._data) == 0:
return False
return True
def get_quality_metrics(self) -> Dict[str, float]:
"""Calculate quality metrics for the recording.
Returns:
dict: Dictionary containing quality metrics:
                - noise_level: Estimated noise level
                - signal_stability: Measure of signal stability
                - artifact_count: Number of potential artifacts
Raises:
ValueError: If no recording data has been set.
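
        Example (illustrative; assumes valid recording data):
            >>> metrics = trial.get_quality_metrics()
            >>> metrics['artifact_count']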
"""
if not self.validate_data():
raise ValueError("Invalid or missing recording data")
# Calculate noise level from the first 10% of the signal
noise_window = slice(0, len(self._data) // 10)
noise_level = np.std(self._data[noise_window])
# Detect potential artifacts (points > 3 std from mean)
mean = np.mean(self._data)
std = np.std(self._data)
artifacts = np.sum(np.abs(self._data - mean) > 3 * std)
        # Calculate signal stability as the relative variation of segment means
        segments = np.array_split(self._data, 10)
        segment_means = [np.mean(seg) for seg in segments]
        mean_of_means = np.mean(segment_means)
        stability = 1 - np.std(segment_means) / mean_of_means if mean_of_means != 0 else 0.0
return {
'noise_level': float(noise_level),
'artifact_count': int(artifacts),
'signal_stability': float(stability)
}
class StimResponse:
"""A class representing a stimulus-response pair in a neural recording.
This class handles the analysis of neural responses to specific stimuli,
extracting various features from the response signal and providing methods
for response characterization.
Attributes:
stim (NDArray): The stimulus signal.
response (NDArray): The response signal.
feature (Dict[int, Any]): Dictionary of extracted features.
neuron: The neuron object associated with this response.
f_sample (float): Sampling frequency in Hz.
sampling_time (float): Time between samples in seconds.
baseline (NDArray): Baseline signal before stimulus.
Features extracted include:
0: Maximum response amplitude
1: Area under the curve
2: Mean response
3: Time to peak
4: Area under the curve to peak
5: Minimum response
6: Response onset time
7: Positive response area
8: Absolute area under the curve
"""
def __init__(self, trial: Trial, stimulus: NDArray, response: NDArray,
baseline_samples: int) -> None:
"""Initialize a StimResponse instance.
Args:
trial: The trial object associated with this response.
stimulus: The stimulus signal.
response: The response signal.
baseline_samples: Number of samples to use for baseline calculation.
Raises:
ValueError: If input arrays have invalid dimensions or lengths.
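
        Example (illustrative sketch; ``trial`` holds a recording and
        ``stim``/``resp`` are equal-length 1-D numpy arrays):
            >>> sr = StimResponse(trial, stim, resp, baseline_samples=50)
            >>> sr.feature[3]  # time to peak, in seconds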
"""
if not isinstance(stimulus, np.ndarray) or not isinstance(response, np.ndarray):
raise ValueError("Stimulus and response must be numpy arrays")
if stimulus.ndim != 1 or response.ndim != 1:
raise ValueError("Stimulus and response must be 1-dimensional")
if len(stimulus) != len(response):
raise ValueError("Stimulus and response must have the same length")
        if not 0 < baseline_samples < len(response):
            raise ValueError("baseline_samples must be positive and smaller than the response length")
self.stim = stimulus
self.response = response
self.feature: Dict[int, Any] = {}
self.neuron = trial.neuron
self.f_sample = F_SAMPLE
self.sampling_time = 1./self.f_sample
self.baseline = self.response[:baseline_samples]
# Extract all features
for feature_index in range(9): # 9 features total
self.feature[feature_index] = self.extract_feature(feature_index)
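
    def extract_feature(self, feature_index: int) -> Any:
        """Extract a single feature from the response signal by index.

        A minimal dispatch sketch: the index-to-feature mapping below follows
        the mapping documented in the class docstring.

        Args:
            feature_index: Index of the feature to extract (0-8).

        Returns:
            The extracted feature value.

        Raises:
            ValueError: If feature_index is outside the supported range.
        """
        extractors = {
            0: self._find_maximum,                   # Maximum response amplitude
            1: self._area_under_the_curve,           # Area under the curve
            2: self._find_mean,                      # Mean response
            3: self._find_time_to_peak,              # Time to peak
            4: self._area_under_the_curve_to_peak,   # Area under the curve to peak
            5: self._find_minimum,                   # Minimum response
            6: self._find_onset_time,                # Response onset time
            7: self._find_positive_area,             # Positive response area
            8: self._absolute_area_under_the_curve,  # Absolute area under the curve
        }
        if feature_index not in extractors:
            raise ValueError(f"Invalid feature index: {feature_index}")
        return extractors[feature_index]()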
def _find_maximum(self) -> float:
"""Find the maximum response amplitude.
Returns:
float: Maximum value of the response signal.
"""
return float(np.max(self.response))
def _find_minimum(self) -> float:
"""Find the minimum response amplitude.
Returns:
float: Minimum value of the response signal.
"""
return float(np.min(self.response))
def _find_time_to_peak(self) -> float:
"""Find the time to response peak.
Returns:
float: Time to peak in seconds.
"""
max_index = np.argmax(self.response)
return float(max_index * self.sampling_time)
def _find_mean(self) -> float:
"""Calculate the mean response amplitude.
Returns:
float: Mean value of the response signal.
"""
return float(np.mean(self.response))
def _area_under_the_curve(self, bin_size: int = 5) -> float:
"""Calculate the total area under the response curve.
Args:
bin_size: Number of samples to bin for integration.
Returns:
float: Area under the curve in amplitude-seconds.
"""
undersampling = self.response[::bin_size]
return float(np.trapz(undersampling, dx=self.sampling_time*bin_size))
def _absolute_area_under_the_curve(self, bin_size: int = 5) -> float:
"""Calculate the absolute area under the response curve.
Args:
bin_size: Number of samples to bin for integration.
Returns:
float: Absolute area under the curve in amplitude-seconds.
"""
undersampling = np.abs(self.response[::bin_size])
return float(np.trapz(undersampling, dx=self.sampling_time*bin_size))
def _area_under_the_curve_to_peak(self, bin_size: int = 10) -> float:
"""Calculate the area under the curve up to the peak response.
Args:
bin_size: Number of samples to bin for integration.
Returns:
float: Area under the curve to peak in amplitude-seconds.
"""
undersampling = self.response[::bin_size]
max_index = np.argmax(undersampling)
window_to_peak = undersampling[:max_index+1]
return float(np.trapz(window_to_peak, dx=self.sampling_time*bin_size))
def _find_onset_time(self, window_size: int = 10, threshold_std: float = 2.0,
absolute_threshold: Optional[float] = None) -> float:
"""Find the response onset time using a sliding window approach.
For calcium imaging data, this method uses a robust detection approach that:
1. Requires the signal to stay above threshold for the entire window
2. Checks that the signal continues to rise after the window
3. Ensures the detected onset is not just a noise spike
4. Requires a moderate increase in signal level
Args:
window_size: Size of the sliding window in samples.
threshold_std: Number of standard deviations above baseline for onset.
Only used if absolute_threshold is None.
absolute_threshold: Absolute threshold value for onset detection.
If provided, overrides threshold_std calculation.
Returns:
float: Onset time in seconds.
"""
baseline_mean = np.mean(self.baseline)
baseline_std = np.std(self.baseline)
# Calculate threshold based on provided method
if absolute_threshold is not None:
threshold = absolute_threshold
else:
# Use statistical threshold
threshold = baseline_mean + threshold_std * baseline_std
# Use a sliding window to find when response consistently exceeds threshold
for i in range(len(self.response) - window_size):
window = self.response[i:i+window_size]
if np.mean(window) > threshold:
return float(i * self.sampling_time)
# If no onset found, return nan
return np.nan
def _find_positive_area(self, bin_size: int = 10) -> tuple[float, float]:
"""Calculate the positive and negative areas of the response.
Args:
bin_size: Number of samples to bin for integration.
Returns:
tuple: (positive_area, negative_area) in amplitude-seconds.
"""
undersampling = self.response[::bin_size]
pos_mask = undersampling > 0
neg_mask = undersampling < 0
pos_area = float(np.trapz(undersampling[pos_mask],
dx=self.sampling_time*bin_size))
neg_area = float(abs(np.trapz(undersampling[neg_mask],
dx=self.sampling_time*bin_size)))
return pos_area, neg_area
def get_response_characteristics(self) -> Dict[str, float]:
"""Calculate comprehensive response characteristics.
Returns:
dict: Dictionary containing various response metrics:
- amplitude: Peak response amplitude (relative to baseline)
- duration: Response duration
- latency: Response latency
- integral: Total response integral
- baseline_mean: Mean baseline activity
- baseline_std: Baseline standard deviation
- signal_to_noise: Signal-to-noise ratio
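
        Example (illustrative; ``sr`` is a StimResponse instance):
            >>> chars = sr.get_response_characteristics()
            >>> chars['amplitude'], chars['latency']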
"""
baseline_mean = np.mean(self.baseline)
baseline_std = np.std(self.baseline)
# Calculate peak amplitude relative to baseline
# First subtract baseline from the entire response
response_minus_baseline = self.response - baseline_mean
# Find the maximum absolute deviation
max_abs_idx = np.argmax(np.abs(response_minus_baseline))
# Use the actual value at that index
peak_amplitude = response_minus_baseline[max_abs_idx]
onset_time = self._find_onset_time()
# Find response end (when signal returns to baseline)
end_threshold = baseline_mean + 2 * baseline_std
response_end = np.where(self.response <= end_threshold)[0]
response_end = response_end[-1] if len(response_end) > 0 else len(self.response)
        # response_end is a sample index; convert to seconds before subtracting the onset time
        duration = max(0.0, response_end * self.sampling_time - onset_time) if not np.isnan(onset_time) else 0.0
integral = self._area_under_the_curve()
# Calculate signal-to-noise ratio
signal_power = np.mean(np.square(self.response - baseline_mean))
noise_power = np.mean(np.square(self.baseline - baseline_mean))
snr = 10 * np.log10(signal_power / noise_power) if noise_power > 0 else float('inf')
return {
'amplitude': float(peak_amplitude),
'duration': float(duration),
'latency': float(onset_time),
'integral': float(integral),
'baseline_mean': float(baseline_mean),
'baseline_std': float(baseline_std),
'signal_to_noise': float(snr)
}
    @staticmethod
    def _linear_transform(value: float, minvalue: float, maxvalue: float) -> float:
        """Linearly rescale ``value`` from [minvalue, maxvalue] to [0, 1]."""
        return (value - minvalue) / (maxvalue - minvalue)