-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathConvertPhyKilo2Neurosuite.m
More file actions
384 lines (330 loc) · 12.4 KB
/
ConvertPhyKilo2Neurosuite.m
File metadata and controls
384 lines (330 loc) · 12.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
function ConvertPhyKilo2Neurosuite(basepath,basename,ks_basepath,varargin)
% CONVERTPHYKILO2NEUROSUITE Export the phy-curated clusters of a Kilosort
% session to Neurosuite/Klusters files (.clu/.res/.fet/.spk). Only clusters
% whose label in cluster_group.tsv starts with 'g' (i.e. 'good') are exported.
%
% Expected layout of a session folder:
% ###
% .
% +-- _folder(basename)
% |   +-- basename.dat
% |   +-- basename.xml
% |   +-- _Kilosort_date_time
% |   |   +-- spike_clusters.npy
% |   |   +-- spike_times.npy
% |   |   +-- cluster_group.tsv
% |   |   +-- pc_features.npy
% |   |   +-- templates.npy
% |   |   +-- rez2.mat (or rez.mat)
% |   |   +-- cluster_ids.npy  (ks1 only)
% |   |   +-- shanks.npy       (ks1 only)
% |   |   +-- cluster_info.tsv (ks2 only)
% ###
%
% KILOSORT1/PHY1 USERS:
%   cluster_ids.npy and shanks.npy are generated by the phy1 plugin
%   "Export Shanks" (https://github.com/petersenpeter/phy-plugins).
%
% KILOSORT2/PHY2 USERS:
%   The standard kilosort2 and phy2 outputs are sufficient; pass
%   'kilosort2',true. The session folder must follow the layout above.
%
% Inputs:
%   basepath    - folder holding the .dat, the .xml and the Kilosort output
%                 folder (default: current directory)
%   basename    - shared file name of the .dat and .xml (default: name of
%                 the last folder of basepath)
%   ks_basepath - Kilosort output folder, relative to basepath or absolute
%                 (default: the last subfolder of basepath, in dir() order,
%                 whose name contains 'Kilosort')
%   'kilosort2' - name/value flag (logical, default false): derive shank
%                 membership from cluster_info.tsv and the .xml spike groups
%                 instead of shanks.npy/cluster_ids.npy
%
% Output files (written to basepath, one set per shank N):
%   basename.clu.N - number of clusters, then one cluster id per spike
%   basename.res.N - spike times (sample index into the .dat), one per line
%   basename.fet.N - PC features, waveform power, per-channel range, time
%   basename.spk.N - int16 waveforms, channel-interleaved
% The .clu files are also copied to basepath/Phy2Clus (kilosort2) or
% basepath/PhyClus. The working directory is changed to basepath.
%
% Eliezyer de Oliveira, 2018
% Reviewed 11/2019 - added support for the outputs of phy2.
if ~exist('basepath','var') || isempty(basepath)
    basepath = cd;
end
if ~exist('basename','var') || isempty(basename)
    [~,basename] = fileparts(basepath);
end
savepath = basepath;

p = inputParser;
addParameter(p,'kilosort2',false,@islogical)
parse(p,varargin{:})
kilosort2 = p.Results.kilosort2;

%% locate the Kilosort output folder and load its rez structure
if exist('ks_basepath','var') && ~isempty(ks_basepath)
    ksPath = resolveKsPath(basepath,ks_basepath);
else
    ksPath = findLastKilosortDir(basepath);
end
rezFile = fullfile(ksPath,'rez2.mat');
if ~exist(rezFile,'file')
    rezFile = fullfile(ksPath,'rez.mat');
end
% load into a struct so variables stored in the .mat cannot clobber locals
rezData = load(rezFile,'rez');
rez = rezData.rez;

%% session parameters from the .xml (prefer basename.xml, else the first .xml)
xmlFile = fullfile(basepath,[basename '.xml']);
if ~exist(xmlFile,'file')
    d = dir(fullfile(basepath,'*.xml'));
    if isempty(d)
        error('the .xml file is missing')
    end
    xmlFile = fullfile(basepath,d(1).name);
end
par = LoadParameters(xmlFile);
totalch = par.nChannels;
[sbefore,safter] = spikeWindow(par);

if exist(rez.ops.fbinary,'file')
    datpath = rez.ops.fbinary;
else
    datpath = fullfile(basepath,[basename '.dat']);
end

%% spikes belonging to clusters labelled 'good' in phy
clusters = readNPY(fullfile(ksPath,'spike_clusters.npy'));
S = tdfread(fullfile(ksPath,'cluster_group.tsv'));
goodIds = S.cluster_id(S.group(:,1) == 'g');
isGood = ismember(clusters,goodIds);
spktimes = uint64(readNPY(fullfile(ksPath,'spike_times.npy')));
spktimes = spktimes(isGood);
clu = uint32(clusters(isGood));
pcFeatures = readNPY(fullfile(ksPath,'pc_features.npy'));
pcFeatures = pcFeatures(isGood,:,:);

%% shank of every cluster
if kilosort2
    % kilosort2: match each cluster's channel (cluster_info.ch) to the .xml spike groups
    cluster_info = tdfread(fullfile(ksPath,'cluster_info.tsv'));
    clu_channels = cluster_info.ch;
    shanks = zeros(size(clu_channels));
    for s = 1:numel(par.spikeGroups.groups)
        shanks(ismember(clu_channels,par.spikeGroups.groups{s})) = s;
    end
    cluShank = cluster_info.id;   % cluster ids, aligned element-wise with shanks
else
    shanks = readNPY(fullfile(ksPath,'shanks.npy'));
    cluShank = readNPY(fullfile(ksPath,'cluster_ids.npy'));   % cluster ids, aligned with shanks
end

cd(basepath)   % preserved from the original implementation
if kilosort2
    folder_name = 'Phy2Clus';
else
    folder_name = 'PhyClus';
end
if ~exist(fullfile(savepath,folder_name),'dir')
    mkdir(fullfile(savepath,folder_name))
end

%% shank assigned to each good cluster (0 = unknown, never exported)
auxC = unique(clu);
templateshankassignments = zeros(size(auxC));
for idx = 1:numel(auxC)
    hit = find(cluShank == auxC(idx),1);
    if isempty(hit)
        warning('Cluster %d has no shank assignment; its spikes are skipped.',auxC(idx))
    else
        templateshankassignments(idx) = shanks(hit);
    end
end

grouplookup = rez.ops.kcoords;
allgroups = unique(grouplookup);
allgroups(allgroups==0) = [];
% map the .dat once, before the loop, so it exists even if the first group is skipped
dat = memmapfile(datpath,'Format','int16');
for groupidx = 1:numel(allgroups)
    tgroup = allgroups(groupidx);                                 % shank number
    tidx = ismember(clu,auxC(templateshankassignments==tgroup));   % spikes on this shank
    tclu = clu(tidx);
    tspktimes = spktimes(tidx);

    % channels of the matching .xml spike group (+1 turns the .xml channel
    % numbers into MATLAB row indices); accepted only if they overlap the
    % channels kilosort placed in this group
    gidx = find(grouplookup == tgroup);
    channellist = [];
    if groupidx <= numel(par.spikeGroups.groups)
        channellist = par.spikeGroups.groups{groupidx}(:) + 1;
    end
    if isempty(channellist) || ~any(ismember(gidx,channellist))
        disp(['Cannot find spkgroup for group ' num2str(groupidx) ])
        continue
    end

    %% waveforms from the .dat
    [wvforms_all,wvranges,wvpowers] = extractWaveforms(dat,tspktimes,channellist,totalch,sbefore,safter);

    %% features: PCs 1-3 of every PC channel, then power, per-channel range, time
    pct = double(pcFeatures(tidx,:,:));
    nSpk = size(pct,1);
    % column order: PC1 of all PC channels, then PC2, then PC3
    % (explicit reshape instead of squeeze so a single-spike shank keeps its shape)
    pcfets = reshape(permute(pct(:,1:3,:),[1 3 2]),nSpk,3*size(pct,3));
    fets = [pcfets, wvpowers, wvranges, double(tspktimes(:))];

    %% write clu, res, fet, spk
    cluname = fullfile(savepath,[basename '.clu.' num2str(tgroup)]);
    resname = fullfile(savepath,[basename '.res.' num2str(tgroup)]);
    fetname = fullfile(savepath,[basename '.fet.' num2str(tgroup)]);
    spkname = fullfile(savepath,[basename '.spk.' num2str(tgroup)]);
    SaveFetIn(fetname,fets);
    writeColumn(cluname,[numel(unique(tclu)); double(tclu)]);   % first line: number of clusters
    writeColumn(resname,double(tspktimes));
    fid = fopen(spkname,'w');
    if fid < 0
        error('Could not open %s for writing',spkname)
    end
    fwrite(fid,wvforms_all,'int16');
    fclose(fid);
    disp(['Shank ' num2str(tgroup) ' done'])
end
clear dat
copyfile(fullfile(savepath,[basename,'.clu.*']),fullfile(savepath,folder_name))

function ksPath = resolveKsPath(basepath,ks_basepath)
% Full path of a user-supplied Kilosort folder: relative to basepath when
% that folder exists, otherwise ks_basepath as given (e.g. an absolute path).
ksPath = fullfile(basepath,ks_basepath);
if ~exist(ksPath,'dir') && exist(ks_basepath,'dir')
    ksPath = ks_basepath;
end

function ksPath = findLastKilosortDir(basepath)
% Full path of the last subfolder of basepath (in dir() order) whose name
% contains 'Kilosort'; errors if there is none.
listing = dir(basepath);
names = {listing([listing.isdir]).name};
hits = names(~cellfun(@isempty,strfind(names,'Kilosort')));
if isempty(hits)
    error('No folder containing ''Kilosort'' was found in %s',basepath)
end
ksPath = fullfile(basepath,hits{end});

function [sbefore,safter] = spikeWindow(par)
% Samples taken before/after each spike time: from the first .xml spike
% group (PeakSample, nSamples - PeakSample) when present, otherwise 16/16.
sbefore = 16;
safter = 16;
if isfield(par,'SpkGrps') && isfield(par.SpkGrps,'nSamples') && isfield(par.SpkGrps,'PeakSample') ...
        && ~isempty(par.SpkGrps(1).nSamples) && ~isempty(par.SpkGrps(1).PeakSample)
    sbefore = par.SpkGrps(1).PeakSample;
    safter = par.SpkGrps(1).nSamples - par.SpkGrps(1).PeakSample;
end

function [wvforms_all,wvranges,wvpowers] = extractWaveforms(dat,tspktimes,channellist,totalch,sbefore,safter)
% Cut a (sbefore+safter)-sample window per spike out of the memory-mapped
% int16 .dat (totalch interleaved channels), keep the rows in channellist
% and subtract each channel's median over the window.
% Returns:
%   wvforms_all - int16 column of all windows, channel-interleaved (every
%                 channel of sample 1, then of sample 2, ...) for the .spk file
%   wvranges    - nSpikes x nChannels, max minus min of each channel
%   wvpowers    - nSpikes x 1, sum of squared samples over the window
% A spike whose window cannot be read (e.g. too close to the start or end
% of the file) is reported and stored as zeros.
tsampsperwave = sbefore + safter;
ngroupchans = numel(channellist);
valsperwave = tsampsperwave*ngroupchans;
nSpk = numel(tspktimes);
wvforms_all = zeros(nSpk*valsperwave,1,'int16');
wvranges = zeros(nSpk,ngroupchans);
wvpowers = zeros(nSpk,1);
for j = 1:nSpk
    t = double(tspktimes(j));
    try
        % Data is the documented memmapfile property holding the mapped samples
        w = dat.Data((t-sbefore)*totalch+1:(t+safter)*totalch);
        wv = reshape(w,totalch,[]);
        wv = wv(channellist,:);                 % ngroupchans x tsampsperwave
        wv = bsxfun(@minus,wv,median(wv,2));    % per-channel median subtraction
    catch
        disp(['Error extracting spike at sample ' int2str(t) '. Saving as zeros']);
        disp(['Time range of that spike was: ' num2str(t-sbefore) ' to ' num2str(t+safter) ' samples'])
        wv = zeros(ngroupchans,tsampsperwave,'int16');
    end
    % work in double: int16 arithmetic saturates at 32767, which would clip
    % both the squared samples and large ranges
    wvd = double(wv);
    wvranges(j,:) = (max(wvd,[],2) - min(wvd,[],2)).';
    wvpowers(j) = sum(wvd(:).^2);
    wvforms_all((j-1)*valsperwave+(1:valsperwave)) = wv(:);
    if rem(j,100000) == 0
        disp([num2str(j) ' out of ' num2str(nSpk) ' done'])
    end
end

function writeColumn(fname,values)
% Write values as integers, one per line, to a new text file.
fid = fopen(fname,'w');
if fid < 0
    error('Could not open %s for writing',fname)
end
fprintf(fid,'%.0f\n',values);
fclose(fid);
function SaveFetIn(FileName, Fet, BufSize)
% SAVEFETIN Write a Neurosuite .fet text file.
%   SaveFetIn(FileName, Fet) writes the number of columns of Fet on the first
%   line, then one tab-separated row per row of Fet. Every column except the
%   last is multiplied by 100 and rounded; the last column is only rounded.
%   SaveFetIn(FileName, Fet, BufSize) formats and writes the rows in chunks of
%   BufSize rows (default inf = all rows at once). The file content is the
%   same either way.
if nargin < 3 || isempty(BufSize)
    BufSize = inf;
end
[nRows, nFeatures] = size(Fet);
% e.g. '%d\t%d\t%d\n' for three features
formatstring = [repmat('%d\t',1,nFeatures-1) '%d\n'];
outputfile = fopen(FileName,'w');
if outputfile < 0
    error('SaveFetIn:openFailed','Could not open %s for writing',FileName)
end
fprintf(outputfile, '%d\n', nFeatures);
if isinf(BufSize)
    BufSize = nRows;
end
BufSize = max(1,floor(BufSize));
% chunk boundaries are computed from BufSize, so every row is written exactly once
for first = 1:BufSize:nRows
    BufInd = first:min(first+BufSize-1,nRows);
    temp = [round(100*Fet(BufInd,1:end-1)) round(Fet(BufInd,end))];
    fprintf(outputfile,formatstring,temp');
end
fclose(outputfile);