"""
Trial-wise neural activity and stimulus-response analysis for CeDNe.
This module defines classes for capturing and analyzing experimental recordings
from neurons. It includes:
- `Trial`: Represents a single experimental recording for a neuron, such as
a calcium imaging time series. Trials are stored per-neuron in the `Neuron.trial`
dictionary and support signal preprocessing (e.g., bleaching correction).
- `StimResponse`: Encapsulates a stimulus-response pair recorded during a `Trial`,
and extracts a set of interpretable features from the response signal, including
max amplitude, onset time, area under the curve, and others.
These classes are designed to support time-locked calcium imaging experiments
and help link dynamic neural activity to behavioral or stimulus-driven contexts.
"""
__author__ = "Sahil Moza"
__date__ = "2025-04-06"
__license__ = "MIT"
from typing import Optional, List, Union, Dict, Any, TYPE_CHECKING
from datetime import datetime
import numpy as np
from scipy import signal
import scipy.stats as ss
from scipy.ndimage import gaussian_filter1d
from numpy.typing import NDArray
from .config import F_SAMPLE
if TYPE_CHECKING:
from .neuron import Neuron
class Trial:
    """A single experimental recording session for one neuron.

    Stores a time series recorded from a neuron during one trial and
    provides signal-processing utilities (filtering, smoothing,
    normalization, peak detection, quality metrics).

    Attributes:
        neuron: The neuron object this trial belongs to.
        i (int): Trial number identifier.
        discard (List[int]): Indices dropped from the start of the raw
            recording (e.g. bleaching artifacts).
        behavior: Optional back-reference to a Behavior object.
        metadata (Dict[str, Any]): Trial metadata, including a log of
            processing operations.
    """

    def __init__(self, neuron: "Neuron", trialnum: int) -> None:
        """Create an empty trial bound to *neuron*.

        Args:
            neuron: The neuron object this trial belongs to.
            trialnum: The trial number identifier.
        """
        self.neuron = neuron
        self.i = trialnum
        self.discard: List[int] = []
        self._data: Optional[NDArray] = None
        # Optional back-reference to a Behavior object.
        self.behavior = None
        self.metadata: Dict[str, Any] = dict(
            trial_number=trialnum,
            neuron_id=id(neuron),
            sampling_rate=F_SAMPLE,
            processing_history=[],
        )
@property
def recording(self) -> NDArray:
    """The recording time series.

    Returns:
        NDArray: The stored 1D time series.

    Raises:
        ValueError: If no recording data has been set.
    """
    if self._data is None:
        raise ValueError("No recording data has been set")
    return self._data

@recording.setter
def recording(self, data: NDArray, discard: float = 0) -> None:
    """Set the recording data.

    NOTE(review): a property setter invoked through assignment
    (``trial.recording = data``) always receives exactly one value, so
    the ``discard`` argument could never be supplied that way; use
    ``set_recording`` to discard an initial bleaching window.

    Args:
        data: The time series data to store.
        discard: Initial seconds to drop (only reachable via a direct
            ``fset`` call; prefer ``set_recording``).
    """
    self.set_recording(data, discard)

def set_recording(self, data: NDArray, discard: float = 0) -> None:
    """Store recording data, optionally dropping an initial window.

    Args:
        data: The time series data to store (coerced to a 1D float64
            numpy array).
        discard: Number of initial seconds to discard (e.g. for
            bleaching correction).

    Raises:
        ValueError: If ``data`` is not 1D, ``discard`` is negative, or
            the discard window covers the whole recording.
    """
    if not isinstance(data, np.ndarray):
        data = np.array(data)
    if data.ndim != 1:
        raise ValueError("Recording data must be a 1D array")
    if discard < 0:
        raise ValueError("Discard cannot be negative")
    if discard > 0:
        # Convert seconds to a sample count at the module sampling rate.
        discard_points = int(discard * F_SAMPLE)
        if discard_points >= len(data):
            raise ValueError("Discard duration exceeds data length")
        self.discard = list(range(discard_points))
        self._data = data[discard_points:].astype(np.float64)
    else:
        self.discard = []
        self._data = data.astype(np.float64)
def get_duration(self) -> float:
    """Return the total recording duration in seconds.

    Returns:
        float: Number of samples divided by the sampling rate.

    Raises:
        ValueError: If no recording data has been set.
    """
    n_samples = len(self.recording)
    return n_samples / F_SAMPLE
def get_timestamps(self) -> NDArray:
    """Return the time of each sample, in seconds from trial start.

    Returns:
        NDArray: Monotonic timestamps ``0, 1/fs, 2/fs, ...``.

    Raises:
        ValueError: If no recording data has been set.
    """
    n_samples = len(self.recording)
    return np.arange(n_samples) / F_SAMPLE
[docs]
def filter_signal(
    self, filter_type: str = "lowpass", cutoff_freq: Union[float, tuple] = 10.0, order: int = 4
) -> NDArray:
    """Apply a zero-phase Butterworth filter to the recording data.

    Args:
        filter_type: Type of filter ('lowpass', 'highpass', or 'bandpass').
        cutoff_freq: Cutoff frequency in Hz; for 'bandpass' supply a pair
            ``(low, high)``.
        order: Order of the Butterworth filter.

    Returns:
        NDArray: Filtered signal (``filtfilt`` gives zero phase distortion).

    Raises:
        ValueError: If filter_type is invalid, the cutoffs fall outside
            (0, Nyquist), or no recording data has been set.
    """
    nyquist = F_SAMPLE / 2
    if isinstance(cutoff_freq, (list, tuple)):
        # Vectorize so the range check covers both band edges at once.
        cutoff_freq = np.array(cutoff_freq)
        if np.any(cutoff_freq <= 0) or np.any(cutoff_freq >= nyquist):
            raise ValueError(
                "Cutoff frequencies must be between 0 and nyquist frequency"
            )
    else:
        if cutoff_freq <= 0 or cutoff_freq >= nyquist:
            raise ValueError(
                "Cutoff frequency must be between 0 and nyquist frequency"
            )
    # butter() takes cutoffs normalized to the Nyquist frequency.
    normalized_cutoff = cutoff_freq / nyquist
    if filter_type == "lowpass":
        b, a = signal.butter(order, normalized_cutoff, btype="low")
    elif filter_type == "highpass":
        b, a = signal.butter(order, normalized_cutoff, btype="high")
    elif filter_type == "bandpass":
        # A scalar cutoff cannot define a band; require exactly two edges.
        if (
            not isinstance(normalized_cutoff, (list, tuple, np.ndarray))
            or len(normalized_cutoff) != 2
        ):
            raise ValueError("Bandpass filter requires two cutoff frequencies")
        b, a = signal.butter(order, normalized_cutoff, btype="band")
    else:
        raise ValueError(f"Invalid filter type: {filter_type}")
    return signal.filtfilt(b, a, self.recording)
def smooth_signal(self, window_size: int = 5, method: str = "moving") -> NDArray:
    """Smooth the recording with a moving-average, Gaussian, or median filter.

    Args:
        window_size: Window length in samples (used as sigma for
            'gaussian'; must be odd for 'median').
        method: One of 'moving', 'gaussian', or 'median'.

    Returns:
        NDArray: The smoothed signal.

    Raises:
        ValueError: If the method name is not recognized, or if no
            recording data has been set.
    """
    trace = self.recording
    if method == "moving":
        kernel = np.full(window_size, 1.0 / window_size)
        return np.convolve(trace, kernel, mode="same")
    if method == "gaussian":
        return gaussian_filter1d(trace, window_size)
    if method == "median":
        return signal.medfilt(trace, window_size)
    raise ValueError(f"Invalid smoothing method: {method}")
def normalize_signal(
    self, method: str = "minmax", baseline_window: Optional[tuple] = None
) -> NDArray:
    """Normalize the recording.

    Args:
        method: 'minmax' (scale to [0, 1]), 'zscore' (zero mean, unit
            std), or 'baseline' (change relative to a baseline window).
        baseline_window: ``(start, end)`` sample indices of the baseline
            region; required for the 'baseline' method.

    Returns:
        NDArray: The normalized signal. NOTE: a constant signal yields
        NaN/inf for 'minmax' and 'zscore' (zero range / zero std).

    Raises:
        ValueError: If the method is unknown, 'baseline' is requested
            without a window, or no recording data has been set.
    """
    trace = self.recording
    if method == "minmax":
        lo, hi = np.min(trace), np.max(trace)
        return (trace - lo) / (hi - lo)
    if method == "zscore":
        return (trace - np.mean(trace)) / np.std(trace)
    if method == "baseline":
        if baseline_window is None:
            raise ValueError(
                "baseline_window must be provided for baseline normalization"
            )
        start, end = baseline_window
        f0 = np.mean(trace[start:end])
        return (trace - f0) / f0
    raise ValueError(f"Invalid normalization method: {method}")
def detect_peaks(
    self, height: Optional[float] = None, distance: Optional[int] = None
) -> tuple[NDArray, NDArray]:
    """Find local maxima in the recording.

    Args:
        height: Minimum peak height, forwarded to
            ``scipy.signal.find_peaks``.
        distance: Minimum separation between peaks, in samples (>= 1).

    Returns:
        tuple: ``(peak_indices, peak_heights)``.

    Raises:
        ValueError: If ``distance`` < 1 or no recording data has been set.
    """
    if distance is not None and distance < 1:
        raise ValueError("`distance` must be greater or equal to 1")
    trace = self.recording
    peaks, props = signal.find_peaks(trace, height=height, distance=distance)
    # find_peaks only reports 'peak_heights' when a height constraint is
    # given; otherwise read the heights off the trace directly.
    if "peak_heights" in props:
        heights = props["peak_heights"]
    else:
        heights = trace[peaks]
    return peaks, heights
def get_statistics(self) -> dict:
    """Summarize the recording with basic descriptive statistics.

    Returns:
        dict: mean, std, median, min, max, skewness, kurtosis
        (Fisher definition, per ``scipy.stats.kurtosis``), and rms.

    Raises:
        ValueError: If no recording data has been set.
    """
    trace = self.recording
    stats = {
        "mean": np.mean(trace),
        "std": np.std(trace),
        "median": np.median(trace),
        "min": np.min(trace),
        "max": np.max(trace),
        "skewness": ss.skew(trace),
        "kurtosis": ss.kurtosis(trace),
        "rms": np.sqrt(np.mean(np.square(trace))),
    }
    return stats
def compute_power_spectrum(self, window: str = "hann") -> tuple[NDArray, NDArray]:
    """Estimate the power spectral density of the recording (Welch method).

    Args:
        window: Window function name understood by ``scipy.signal.welch``.

    Returns:
        tuple: ``(frequencies, psd)`` arrays.

    Raises:
        ValueError: If no recording data has been set.
    """
    # welch already returns the (freqs, psd) pair in the required order.
    return signal.welch(self.recording, F_SAMPLE, window=window)
def compute_snr(self, signal_window: tuple, noise_window: tuple) -> float:
    """Compute the signal-to-noise ratio between two regions of the trace.

    Args:
        signal_window: ``(start, end)`` sample indices of the signal region.
        noise_window: ``(start, end)`` sample indices of the noise region.

    Returns:
        float: SNR in dB; ``inf`` when the noise power is zero.

    Raises:
        ValueError: If no recording data has been set.
    """
    trace = self.recording
    s0, s1 = signal_window
    n0, n1 = noise_window
    p_signal = np.mean(np.square(trace[s0:s1]))
    p_noise = np.mean(np.square(trace[n0:n1]))
    if p_noise == 0:
        return float("inf")
    return 10 * np.log10(p_signal / p_noise)
def segment_signal(
    self, threshold: float, min_duration: int = 10
) -> List[tuple[int, int]]:
    """Segment the signal into contiguous regions above a threshold.

    Bug fix: the original returned no segments whenever either crossing
    list was empty, so a trace that starts above threshold, ends above
    threshold, or never dips below it at all was silently dropped. The
    boundary segments are now recovered from the first/last sample state.

    Args:
        threshold: Amplitude threshold for segmentation.
        min_duration: Minimum duration (in samples) for a valid segment.

    Returns:
        List[tuple]: ``(start, end)`` index pairs, end-exclusive.

    Raises:
        ValueError: If no recording data has been set.
    """
    above_threshold = self.recording > threshold
    if not np.any(above_threshold):
        return []
    changes = np.diff(above_threshold.astype(int))
    rise_points = np.where(changes == 1)[0] + 1
    fall_points = np.where(changes == -1)[0] + 1
    # If the signal begins above threshold there is no recorded rise;
    # the first segment starts at sample 0.
    if above_threshold[0]:
        rise_points = np.insert(rise_points, 0, 0)
    # If the signal ends above threshold there is no recorded fall;
    # the last segment runs to the end of the recording.
    if above_threshold[-1]:
        fall_points = np.append(fall_points, len(self.recording))
    return [
        (start, end)
        for start, end in zip(rise_points, fall_points)
        if end - start >= min_duration
    ]
def log_processing(self, operation: str, parameters: Dict[str, Any]) -> None:
    """Append a processing operation to the trial's history log.

    Args:
        operation: Name of the processing operation.
        parameters: Parameters used in the operation.
    """
    entry = {
        "timestamp": datetime.now().isoformat(),
        "operation": operation,
        "parameters": parameters,
    }
    self.metadata["processing_history"].append(entry)
def validate_data(self) -> bool:
    """Check that the stored recording is usable for analysis.

    Valid data is a non-empty, 1-dimensional numpy array containing no
    NaN or infinite values.

    Returns:
        bool: True if the data passes all checks, False otherwise.
    """
    data = self._data
    # Short-circuiting keeps the later checks safe for non-array data.
    return (
        data is not None
        and isinstance(data, np.ndarray)
        and data.ndim == 1
        and bool(np.isfinite(data).all())
        and len(data) > 0
    )
def get_quality_metrics(self) -> Dict[str, float]:
    """Calculate quality metrics for the recording.

    Fixes vs. original: removed an unused ``segment_length`` local,
    corrected the docstring (no ``snr`` key is returned), and guarded
    the stability ratio against a zero mean signal level.

    Returns:
        dict: Quality metrics with keys:
            - noise_level: Std. dev. of the first 10% of the trace,
              used as a noise estimate.
            - artifact_count: Number of samples deviating more than
              3 standard deviations from the mean.
            - signal_stability: 1 minus the coefficient of variation of
              the means of 10 equal segments (NaN when the mean signal
              level is zero).

    Raises:
        ValueError: If the recording data is missing or invalid.
    """
    if not self.validate_data():
        raise ValueError("Invalid or missing recording data")
    data = self._data
    # Noise estimate from the first 10% of the trace (assumes the trace
    # starts in a quiet/baseline period -- TODO confirm with callers).
    noise_level = np.std(data[: len(data) // 10])
    # Potential artifacts: samples more than 3 std from the mean.
    mean = np.mean(data)
    std = np.std(data)
    artifact_count = int(np.sum(np.abs(data - mean) > 3 * std))
    # Stability: variation of segment means relative to overall level.
    segment_means = [np.mean(seg) for seg in np.array_split(data, 10)]
    mean_level = np.mean(segment_means)
    if mean_level == 0:
        stability = float("nan")
    else:
        stability = 1 - np.std(segment_means) / mean_level
    return {
        "noise_level": float(noise_level),
        "artifact_count": artifact_count,
        "signal_stability": float(stability),
    }
class StimResponse:
    """A stimulus-response pair recorded during a trial.

    Extracts a fixed set of scalar features from the response signal and
    provides response-characterization helpers.

    Attributes:
        stim (NDArray): The stimulus signal.
        response (NDArray): The response signal.
        feature (Dict[int, Any]): Extracted features, keyed by index:
            0: Maximum response amplitude
            1: Area under the curve
            2: Mean response
            3: Time to peak
            4: Area under the curve to peak
            5: Minimum response
            6: Response onset time
            7: Positive response area
            8: Absolute area under the curve
        neuron: The neuron object of the parent trial.
        f_sample (float): Sampling frequency in Hz.
        sampling_time (float): Seconds between samples.
        baseline (NDArray): Leading, pre-stimulus portion of the response.
    """

    def __init__(
        self, trial: Trial, stimulus: NDArray, response: NDArray, baseline_samples: int
    ) -> None:
        """Validate inputs and extract all response features.

        Args:
            trial: The trial this response was recorded in.
            stimulus: The stimulus signal (1D, same length as response).
            response: The response signal (1D).
            baseline_samples: Number of leading samples treated as baseline.

        Raises:
            ValueError: If the inputs are not 1D numpy arrays of equal
                length, or the baseline covers the whole response.
        """
        if not isinstance(stimulus, np.ndarray) or not isinstance(response, np.ndarray):
            raise ValueError("Stimulus and response must be numpy arrays")
        if stimulus.ndim != 1 or response.ndim != 1:
            raise ValueError("Stimulus and response must be 1-dimensional")
        if len(stimulus) != len(response):
            raise ValueError("Stimulus and response must have the same length")
        if baseline_samples >= len(response):
            raise ValueError("Baseline samples exceeds response length")
        self.stim = stimulus
        self.response = response
        self.feature: Dict[int, Any] = {}
        self.neuron = trial.neuron
        self.f_sample = F_SAMPLE
        self.sampling_time = 1.0 / self.f_sample
        self.baseline = self.response[:baseline_samples]
        # Populate the nine indexed features. NOTE(review): extract_feature
        # is not defined in this chunk -- presumably it dispatches on the
        # index to the _find_* / _area_* helpers; confirm elsewhere.
        for feature_index in range(9):
            self.feature[feature_index] = self.extract_feature(feature_index)
def _find_maximum(self) -> float:
"""Find the maximum response amplitude.
Returns:
float: Maximum value of the response signal.
"""
return float(np.max(self.response))
def _find_minimum(self) -> float:
"""Find the minimum response amplitude.
Returns:
float: Minimum value of the response signal.
"""
return float(np.min(self.response))
def _find_time_to_peak(self) -> float:
"""Find the time to response peak.
Returns:
float: Time to peak in seconds.
"""
max_index = np.argmax(self.response)
return float(max_index * self.sampling_time)
def _find_mean(self) -> float:
"""Calculate the mean response amplitude.
Returns:
float: Mean value of the response signal.
"""
return float(np.mean(self.response))
def _area_under_the_curve(self, bin_size: int = 5) -> float:
"""Calculate the total area under the response curve.
Args:
bin_size: Number of samples to bin for integration.
Returns:
float: Area under the curve in amplitude-seconds.
"""
undersampling = self.response[::bin_size]
return float(np.trapz(undersampling, dx=self.sampling_time * bin_size))
def _absolute_area_under_the_curve(self, bin_size: int = 5) -> float:
"""Calculate the absolute area under the response curve.
Args:
bin_size: Number of samples to bin for integration.
Returns:
float: Absolute area under the curve in amplitude-seconds.
"""
undersampling = np.abs(self.response[::bin_size])
return float(np.trapz(undersampling, dx=self.sampling_time * bin_size))
def _area_under_the_curve_to_peak(self, bin_size: int = 10) -> float:
"""Calculate the area under the curve up to the peak response.
Args:
bin_size: Number of samples to bin for integration.
Returns:
float: Area under the curve to peak in amplitude-seconds.
"""
undersampling = self.response[::bin_size]
max_index = np.argmax(undersampling)
window_to_peak = undersampling[: max_index + 1]
return float(np.trapz(window_to_peak, dx=self.sampling_time * bin_size))
def _find_onset_time(
self,
window_size: int = 10,
threshold_std: float = 2.0,
absolute_threshold: Optional[float] = None,
) -> float:
"""Find the response onset time using a sliding window approach.
For calcium imaging data, this method uses a robust detection approach that:
1. Requires the signal to stay above threshold for the entire window
2. Checks that the signal continues to rise after the window
3. Ensures the detected onset is not just a noise spike
4. Requires a moderate increase in signal level
Args:
window_size: Size of the sliding window in samples.
threshold_std: Number of standard deviations above baseline for onset.
Only used if absolute_threshold is None.
absolute_threshold: Absolute threshold value for onset detection.
If provided, overrides threshold_std calculation.
Returns:
float: Onset time in seconds.
"""
baseline_mean = np.mean(self.baseline)
baseline_std = np.std(self.baseline)
# Calculate threshold based on provided method
if absolute_threshold is not None:
threshold = absolute_threshold
else:
# Use statistical threshold
threshold = baseline_mean + threshold_std * baseline_std
# Use a sliding window to find when response consistently exceeds threshold
for i in range(len(self.response) - window_size):
window = self.response[i : i + window_size]
if np.mean(window) > threshold:
return float(i * self.sampling_time)
# If no onset found, return nan
return np.nan
def _find_positive_area(self, bin_size: int = 5) -> tuple[float, float]:
"""Calculate the positive and negative areas of the response.
Args:
bin_size: Number of samples to bin for integration.
Returns:
tuple: (positive_area, negative_area) in amplitude-seconds.
"""
undersampling = self.response[::bin_size]
positive_trace = np.clip(undersampling, 0, None)
negative_trace = np.clip(-undersampling, 0, None)
dx = self.sampling_time * bin_size
pos_area = float(np.trapz(positive_trace, dx=dx))
neg_area = float(np.trapz(negative_trace, dx=dx))
abs_area = float(np.trapz(np.abs(undersampling), dx=dx))
signed_total = pos_area + neg_area
if signed_total > abs_area and np.isclose(signed_total, abs_area):
scale = abs_area / signed_total
pos_area *= scale
neg_area *= scale
return pos_area, neg_area
def get_response_characteristics(self) -> Dict[str, float]:
    """Calculate comprehensive response characteristics.

    Bug fix: ``duration`` previously subtracted the onset time (seconds)
    directly from the response-end sample *index*; the end index is now
    converted to seconds before the subtraction.

    Returns:
        dict: Response metrics:
            - amplitude: Signed peak deviation from baseline (value at
              the sample of largest absolute deviation).
            - duration: Time from onset to response end, in seconds
              (0.0 when no onset is detected).
            - latency: Response onset time in seconds (NaN if none).
            - integral: Area under the response curve.
            - baseline_mean: Mean baseline activity.
            - baseline_std: Baseline standard deviation.
            - signal_to_noise: SNR in dB (inf for zero noise power).
    """
    baseline_mean = np.mean(self.baseline)
    baseline_std = np.std(self.baseline)
    # Signed peak amplitude: value at the largest |deviation| from baseline.
    deviation = self.response - baseline_mean
    peak_amplitude = deviation[np.argmax(np.abs(deviation))]
    onset_time = self._find_onset_time()
    # Response end: last sample at or below baseline + 2 std.
    # NOTE(review): this picks the last sub-threshold sample anywhere in
    # the trace, which may precede the onset -- confirm intended semantics.
    end_threshold = baseline_mean + 2 * baseline_std
    below = np.where(self.response <= end_threshold)[0]
    response_end = below[-1] if len(below) > 0 else len(self.response)
    # BUGFIX: convert the end index to seconds before subtracting the
    # onset time (the original mixed sample indices with seconds).
    duration = (
        response_end * self.sampling_time - onset_time
        if not np.isnan(onset_time)
        else 0.0
    )
    integral = self._area_under_the_curve()
    # SNR from mean squared deviation of response vs. baseline.
    signal_power = np.mean(np.square(self.response - baseline_mean))
    noise_power = np.mean(np.square(self.baseline - baseline_mean))
    snr = (
        10 * np.log10(signal_power / noise_power)
        if noise_power > 0
        else float("inf")
    )
    return {
        "amplitude": float(peak_amplitude),
        "duration": float(duration),
        "latency": float(onset_time),
        "integral": float(integral),
        "baseline_mean": float(baseline_mean),
        "baseline_std": float(baseline_std),
        "signal_to_noise": float(snr),
    }
def _linear_transform(value, minvalue, maxvalue):
    # Min-max scale `value` into [0, 1] relative to [minvalue, maxvalue].
    # NOTE(review): defined without `self`, so this reads as a module-level
    # helper -- confirm it is not meant to be a @staticmethod of
    # StimResponse. Divides by zero when maxvalue == minvalue; callers
    # must ensure a non-degenerate range.
    return (value - minvalue) / (maxvalue - minvalue)