diff --git a/ecephys_spike_sorting/common/utils.py b/ecephys_spike_sorting/common/utils.py index c1e0fd2b..1012e0a5 100644 --- a/ecephys_spike_sorting/common/utils.py +++ b/ecephys_spike_sorting/common/utils.py @@ -313,6 +313,26 @@ def load_kilosort_data(folder, else: return spike_times, spike_clusters, spike_templates, amplitudes, unwhitened_temps, channel_map, cluster_ids, cluster_quality, pc_features, pc_feature_ind +def load_channel_positions(folder): + """ + Loads Kilosort output files from a directory + + Inputs: + ------- + folder : String + Location of Kilosort output directory + + Outputs: + -------- + channel_positions : numpy.ndarray + x,y positions of each channel + """ + + channel_positions = np.squeeze(load(folder,'channel_positions.npy')) + + return channel_positions + + def get_spike_depths(spike_templates, pc_features, pc_feature_ind): diff --git a/ecephys_spike_sorting/modules/mean_waveforms/__main__.py b/ecephys_spike_sorting/modules/mean_waveforms/__main__.py index e9a13887..99a79320 100644 --- a/ecephys_spike_sorting/modules/mean_waveforms/__main__.py +++ b/ecephys_spike_sorting/modules/mean_waveforms/__main__.py @@ -6,7 +6,7 @@ import numpy as np import pandas as pd -from ...common.utils import load_kilosort_data +from ...common.utils import load_kilosort_data,load_channel_positions from .extract_waveforms import extract_waveforms, writeDataAsNpy from .waveform_metrics import calculate_waveform_metrics @@ -26,6 +26,8 @@ def calculate_mean_waveforms(args): load_kilosort_data(args['directories']['kilosort_output_directory'], \ args['ephys_params']['sample_rate'], \ convert_to_seconds = False) + + channel_positions = load_channel_positions(args['directories']['kilosort_output_directory']) print("Calculating mean waveforms...") @@ -34,6 +36,7 @@ def calculate_mean_waveforms(args): spike_templates, templates, channel_map, + channel_positions, args['ephys_params']['bit_volts'], \ args['ephys_params']['sample_rate'], \ args['ephys_params']['vertical_site_spacing'], \ diff --git a/ecephys_spike_sorting/modules/mean_waveforms/extract_waveforms.py b/ecephys_spike_sorting/modules/mean_waveforms/extract_waveforms.py index 4e5d9812..afbf942f 100644 --- a/ecephys_spike_sorting/modules/mean_waveforms/extract_waveforms.py +++ b/ecephys_spike_sorting/modules/mean_waveforms/extract_waveforms.py @@ -16,7 +16,8 @@ def extract_waveforms(raw_data, spike_clusters, spike_templates, templates, - channel_map, + channel_map, + channel_positions, bit_volts, sample_rate, site_spacing, @@ -130,6 +131,7 @@ def extract_waveforms(raw_data, cluster_id, peak_channels[target_template_id], channel_map, + channel_positions, sample_rate, upsampling_factor, spread_threshold, diff --git a/ecephys_spike_sorting/modules/mean_waveforms/waveform_metrics.py b/ecephys_spike_sorting/modules/mean_waveforms/waveform_metrics.py index c6067723..322d5672 100644 --- a/ecephys_spike_sorting/modules/mean_waveforms/waveform_metrics.py +++ b/ecephys_spike_sorting/modules/mean_waveforms/waveform_metrics.py @@ -9,6 +9,7 @@ def calculate_waveform_metrics(waveforms, cluster_id, peak_channel, channel_map, + channel_positions, sample_rate, upsampling_factor, spread_threshold, @@ -75,7 +76,7 @@ def calculate_waveform_metrics(waveforms, mean_1D_waveform, timestamps) recovery_slope = calculate_waveform_recovery_slope( mean_1D_waveform, timestamps) - + site_spacing = (channel_positions[2,1] - channel_positions[0,1])/2 * 10e-7 # calculate site spacing, compatible for both npx 1.0 and npx 2.0 amplitude, spread, velocity_above, velocity_below = calculate_2D_features( mean_2D_waveform, timestamps, local_peak, spread_threshold, site_range, site_spacing)