Skip to content

Commit 3d3a2f1

Browse files
HAM41claude
andcommitted
Extend GLM analysis with glmnet implementation and enhanced DV extraction
- Add comprehensive GLM analysis using glmnet for clicks-to-DV modeling - Implement basis functions analysis for temporal feature extraction - Extend decision variables extraction with additional analysis methods - Update Adrian comparison notebook with expanded visualizations - Add MATLAB utilities for enhanced source data extraction - Update dependencies with scikit-learn and related packages - Document GLM planning and methodology improvements 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent ac06584 commit 3d3a2f1

8 files changed

Lines changed: 63634 additions & 5479 deletions

notebooks/analysis/decision_variables_adrian_comparison.ipynb

Lines changed: 54934 additions & 16 deletions
Large diffs are not rendered by default.

notebooks/analysis/decision_variables_extraction_v003.ipynb

Lines changed: 5532 additions & 2052 deletions
Large diffs are not rendered by default.

notebooks/analysis/glm_basis_functions_analysis.md

Lines changed: 572 additions & 0 deletions
Large diffs are not rendered by default.

notebooks/analysis/glm_clicks_to_dv_glmnet.ipynb

Lines changed: 1932 additions & 0 deletions
Large diffs are not rendered by default.

notebooks/analysis/glm_clicks_to_dv_plan.md

Lines changed: 209 additions & 159 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ authors = [
1010
{name = "Brody-Daw Lab"}
1111
]
1212
readme = "README.md"
13-
requires-python = ">=3.8"
13+
requires-python = ">=3.10"
1414
classifiers = [
1515
"Development Status :: 3 - Alpha",
1616
"Intended Audience :: Science/Research",
@@ -23,8 +23,8 @@ classifiers = [
2323
]
2424

2525
dependencies = [
26-
"numpy>=1.21.0",
27-
"scipy>=1.7.0",
26+
"numpy>=1.26.0",
27+
"scipy>=1.11.0",
2828
"pandas>=1.3.0",
2929
"matplotlib>=3.4.0",
3030
"seaborn>=0.11.0",
@@ -35,8 +35,8 @@ dependencies = [
3535
"pyyaml>=5.4.0",
3636
"jupyter>=1.1.1",
3737
"ipykernel>=6.29.5",
38-
"glmnet-py>=0.1.0b2",
3938
"tables>=3.8.0",
39+
"python-glmnet>=2.6.0",
4040
]
4141

4242
[project.optional-dependencies]

src/utils/matlab/extract_source_data.m

Lines changed: 86 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -39,16 +39,9 @@
3939
cells = source_data.Cells;
4040
fprintf('Found Cells array with dimensions: [%s]\n', num2str(size(cells)));
4141

42-
% Extract neural and session data
43-
fprintf('\n=== Extracting Neural Data ===\n');
44-
[neural_data, session_data] = extract_neural_data(cells);
45-
46-
% Extract trial data
47-
fprintf('\n=== Extracting Trial Data ===\n');
48-
trial_data = extract_trial_data(cells);
49-
50-
% Combine all extracted data
51-
extracted_data = merge_structures(neural_data, session_data, trial_data);
42+
% Extract and concatenate raw data across all probes
43+
fprintf('\n=== Extracting and Concatenating Raw Data ===\n');
44+
extracted_data = extract_and_concatenate_raw_data(cells);
5245

5346
% Display summary
5447
display_extraction_summary(extracted_data);
@@ -62,30 +55,18 @@
6255

6356
end
6457

65-
function [neural_data, session_data] = extract_neural_data(cells)
66-
% Extract spike times, regions, hemispheres, and quality metrics from Cells structure
58+
function extracted_data = extract_and_concatenate_raw_data(cells)
59+
% Extract and concatenate all raw fields across probes without processing
6760

68-
neural_data = struct();
69-
session_data = struct();
61+
extracted_data = struct();
7062

71-
% Initialize collections - raw data
63+
% Initialize concatenation arrays
7264
all_spike_times = {};
7365
all_hemispheres = {};
7466
all_regions = {};
75-
all_electrodes = {};
76-
77-
% Initialize collections - quality metrics
78-
all_quality_metrics = struct();
79-
all_quality_metrics.spatial_spread_um = [];
80-
all_quality_metrics.peak_width_ms = [];
81-
all_quality_metrics.peak_trough_width_ms = [];
82-
all_quality_metrics.upward_going = [];
83-
all_quality_metrics.uvpp = [];
84-
85-
% Initialize collections - filtered data
86-
filt_spike_times = {};
87-
filt_hemispheres = {};
88-
filt_regions = {};
67+
all_raw_quality_metrics = {};
68+
all_coordinates = [];
69+
all_other_unit_fields = containers.Map();
8970

9071
% Process each probe
9172
num_probes = size(cells, 2);
@@ -94,155 +75,133 @@
9475
for probe_idx = 1:num_probes
9576
fprintf(' Probe %d: ', probe_idx);
9677

97-
% Use array indexing instead of cell indexing
9878
cell_data = cells(1, probe_idx);
9979

100-
% Debug: check what we got
101-
fprintf('type=%s, ', class(cell_data));
102-
10380
if isempty(cell_data)
10481
fprintf('empty\n');
10582
continue;
10683
end
10784

108-
% Navigate to the cell structure - handle different possible structures
85+
% Navigate to the cell structure
10986
try
11087
if iscell(cell_data) && ~isempty(cell_data)
11188
cell_struct = cell_data{1,1};
11289
else
113-
% Try direct array access
11490
cell_struct = cell_data(1,1);
11591
end
11692
catch
11793
fprintf('cannot access structure\n');
11894
continue;
11995
end
12096

121-
% Count units with spike data
97+
% Count units and concatenate spike times
12298
n_units = 0;
12399
if isfield(cell_struct, 'raw_spike_time_s')
124100
spike_times = cell_struct.raw_spike_time_s;
125-
126-
% Count non-empty spike time arrays
127101
for i = 1:numel(spike_times)
128102
if ~isempty(spike_times{i})
129103
n_units = n_units + 1;
130-
all_spike_times{end+1} = spike_times{i}(:); % Ensure column vector
104+
all_spike_times{end+1} = spike_times{i}(:);
131105
end
132106
end
133107
end
134108

135-
fprintf('%d units with spikes\n', n_units);
136-
137-
% Extract hemisphere and region data
138-
hemisphere_str = '';
139-
if isfield(cell_struct, 'hemisphere')
140-
hemisphere_obj = cell_struct.hemisphere;
141-
hemisphere_str = extract_mcos_string(hemisphere_obj, 'hemisphere');
142-
end
143-
144-
region_str = '';
145-
if isfield(cell_struct, 'region')
146-
region_obj = cell_struct.region;
147-
region_str = extract_mcos_string(region_obj, 'region');
148-
end
109+
fprintf('%d units\n', n_units);
149110

150-
% Extract quality metrics for this probe
151-
probe_quality = extract_quality_metrics(cell_struct);
111+
% Concatenate metadata fields for each unit
112+
unit_level_fields = {'hemisphere', 'region', 'electrode', 'quality_metrics', ...
113+
'waveform', 'distance_from_tip', 'included_units', ...
114+
'frac_spikes_removed'};
152115

153-
% Process each unit - just extract data without filtering
154-
unit_idx = 0;
155-
if isfield(cell_struct, 'raw_spike_time_s')
156-
spike_times = cell_struct.raw_spike_time_s;
157-
158-
for i = 1:numel(spike_times)
159-
if ~isempty(spike_times{i})
160-
unit_idx = unit_idx + 1;
161-
162-
% Add to raw collections
163-
all_hemispheres{end+1} = hemisphere_str;
164-
all_regions{end+1} = region_str;
165-
all_electrodes{end+1} = [];
166-
167-
% Add quality metrics for this specific unit
168-
if ~isempty(probe_quality) && unit_idx <= length(probe_quality.spatial_spread_um)
169-
all_quality_metrics.spatial_spread_um(end+1) = probe_quality.spatial_spread_um(unit_idx);
170-
all_quality_metrics.peak_width_ms(end+1) = probe_quality.peak_width_ms(unit_idx);
171-
all_quality_metrics.peak_trough_width_ms(end+1) = probe_quality.peak_trough_width_ms(unit_idx);
172-
all_quality_metrics.upward_going(end+1) = probe_quality.upward_going(unit_idx);
173-
all_quality_metrics.uvpp(end+1) = probe_quality.uvpp(unit_idx);
174-
else
175-
% No quality metrics available - fill with NaN/default values
176-
all_quality_metrics.spatial_spread_um(end+1) = NaN;
177-
all_quality_metrics.peak_width_ms(end+1) = NaN;
178-
all_quality_metrics.peak_trough_width_ms(end+1) = NaN;
179-
all_quality_metrics.upward_going(end+1) = false;
180-
all_quality_metrics.uvpp(end+1) = NaN;
116+
for field_name = unit_level_fields
117+
field = field_name{1};
118+
if isfield(cell_struct, field)
119+
field_data = cell_struct.(field);
120+
121+
% Store raw field data (replicated for each unit if needed)
122+
if strcmp(field, 'hemisphere') || strcmp(field, 'region')
123+
% String fields - extract and replicate
124+
str_val = extract_mcos_string(field_data, field);
125+
for u = 1:n_units
126+
if strcmp(field, 'hemisphere')
127+
all_hemispheres{end+1} = str_val;
128+
elseif strcmp(field, 'region')
129+
all_regions{end+1} = str_val;
130+
end
131+
end
132+
else
133+
% Store raw data for Python processing
134+
if ~all_other_unit_fields.isKey(field)
135+
all_other_unit_fields(field) = {};
181136
end
137+
unit_data = all_other_unit_fields(field);
138+
unit_data{end+1} = field_data; % Store probe-level data
139+
all_other_unit_fields(field) = unit_data;
182140
end
183141
end
184142
end
185143

144+
% Store coordinates (probe-level)
145+
probe_coords = struct();
146+
coord_fields = {'AP', 'ML', 'DV'};
147+
for coord_field = coord_fields
148+
field = coord_field{1};
149+
if isfield(cell_struct, field)
150+
probe_coords.(field) = cell_struct.(field);
151+
end
152+
end
153+
all_coordinates = [all_coordinates; probe_coords];
154+
186155
% Extract session metadata from first probe only
187156
if probe_idx == 1
188-
session_fields = {'nTrials', 'removed_trials', 'sessid', 'sess_date', 'rat'};
157+
session_fields = {'nTrials', 'removed_trials', 'sessid', 'sess_date', 'rat', ...
158+
'bank', 'penetration', 'rec', 'shank', 'probe_serial'};
189159

190-
for field_idx = 1:length(session_fields)
191-
field_name = session_fields{field_idx};
192-
if isfield(cell_struct, field_name)
193-
field_data = cell_struct.(field_name);
160+
for field_name = session_fields
161+
field = field_name{1};
162+
if isfield(cell_struct, field)
163+
field_data = cell_struct.(field);
194164

195-
% Handle string fields (potentially MCOS)
196-
if ismember(field_name, {'sess_date', 'rat'})
197-
session_data.(field_name) = extract_mcos_string(field_data, field_name);
165+
% Handle string fields
166+
if ismember(field, {'sess_date', 'rat', 'probe_serial'})
167+
extracted_data.(field) = extract_mcos_string(field_data, field);
198168
else
199-
session_data.(field_name) = field_data;
169+
extracted_data.(field) = field_data;
200170
end
201171
end
202172
end
173+
174+
% Extract trial data from first probe
175+
if isfield(cell_struct, 'Trials')
176+
trials = cell_struct.Trials;
177+
if ~isempty(trials)
178+
fprintf(' Extracting trial data...\n');
179+
extracted_data = extract_raw_trial_data(trials, extracted_data);
180+
end
181+
end
203182
end
204183
end
205184

206-
% Format neural data arrays
185+
% Format concatenated neural data
207186
if ~isempty(all_spike_times)
208-
% Raw data (unfiltered)
209-
neural_data.raw_spike_time_s = all_spike_times';
210-
neural_data.hemisphere = all_hemispheres;
211-
neural_data.region = all_regions';
212-
neural_data.electrode = []; % Empty as in source
187+
extracted_data.raw_spike_time_s = all_spike_times';
188+
extracted_data.hemisphere = all_hemispheres;
189+
extracted_data.region = all_regions';
190+
extracted_data.electrode = []; % Empty placeholder
213191

214-
% Quality metrics for all neurons
215-
if ~isempty(all_quality_metrics.spatial_spread_um)
216-
neural_data.quality_spatial_spread_um = all_quality_metrics.spatial_spread_um';
217-
neural_data.quality_peak_width_ms = all_quality_metrics.peak_width_ms';
218-
neural_data.quality_peak_trough_width_ms = all_quality_metrics.peak_trough_width_ms';
219-
neural_data.quality_upward_going = all_quality_metrics.upward_going';
220-
neural_data.quality_uvpp = all_quality_metrics.uvpp';
221-
222-
fprintf(' Total neurons: %d (with quality metrics)\n', length(all_spike_times));
223-
224-
% Add quality metrics summary
225-
valid_spatial = ~isnan(all_quality_metrics.spatial_spread_um);
226-
valid_uvpp = ~isnan(all_quality_metrics.uvpp);
227-
228-
if any(valid_spatial)
229-
fprintf(' Quality metrics summary:\n');
230-
fprintf(' Spatial spread: %.1f±%.1f μm (n=%d)\n', ...
231-
mean(all_quality_metrics.spatial_spread_um(valid_spatial)), ...
232-
std(all_quality_metrics.spatial_spread_um(valid_spatial)), ...
233-
sum(valid_spatial));
234-
fprintf(' Peak width: %.2f±%.2f ms\n', mean(all_quality_metrics.peak_width_ms(valid_spatial)), std(all_quality_metrics.peak_width_ms(valid_spatial)));
235-
fprintf(' Peak-trough width: %.2f±%.2f ms\n', mean(all_quality_metrics.peak_trough_width_ms(valid_spatial)), std(all_quality_metrics.peak_trough_width_ms(valid_spatial)));
236-
fprintf(' uVpp: %.1f±%.1f μV\n', mean(all_quality_metrics.uvpp(valid_uvpp)), std(all_quality_metrics.uvpp(valid_uvpp)));
237-
fprintf(' Upward-going: %d/%d (%.1f%%)\n', sum(all_quality_metrics.upward_going), length(all_quality_metrics.upward_going), 100*mean(all_quality_metrics.upward_going));
238-
end
239-
else
240-
fprintf(' Total neurons: %d (no quality metrics extracted)\n', length(all_spike_times));
192+
% Store other raw unit-level data for Python processing
193+
unit_field_keys = keys(all_other_unit_fields);
194+
for i = 1:length(unit_field_keys)
195+
field = unit_field_keys{i};
196+
extracted_data.(field) = all_other_unit_fields(field);
241197
end
242-
else
243-
fprintf(' No neural data found!\n');
198+
199+
fprintf(' Total concatenated neurons: %d\n', length(all_spike_times));
244200
end
245201

202+
% Store probe coordinates
203+
extracted_data.probe_coordinates = all_coordinates;
204+
246205
end
247206

248207
function trial_data = extract_trial_data(cells)

0 commit comments

Comments
 (0)