-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathloading_utils.py
More file actions
129 lines (92 loc) · 5.22 KB
/
loading_utils.py
File metadata and controls
129 lines (92 loc) · 5.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import numpy as np
import pandas as pd
import pickle
import processing_utils as pu
# Defaults consumed by load_data_2023 (and baked into its cache filename).
DT = 0.005 # [s] spike-density bin size
TAU = 0.01 # [s] spike-density smoothing time constant
QUERY = "group=='good'"  # default DataFrame.query filter: keep only 'good' units
def load_all_raw_data(dataset_names=None, filter_query=None, random_sample=None) -> pd.DataFrame:
    """Load one or more raw spike-sorting datasets into a single DataFrame.

    For each dataset directory under ``../../rawdata/`` this reads the
    Kilosort/Phy-style outputs (``spike_clusters.npy``, ``spike_times.npy``,
    ``cluster_info.tsv``, ``region_map.tsv``, ``electrode_map.tsv``,
    ``event_times.tsv``), attaches a brain-region label and the per-unit
    spike-time arrays (converted from samples to seconds at 30 kHz), and
    concatenates everything.

    Parameters
    ----------
    dataset_names : list[str] | None
        Dataset directory names relative to ``../../rawdata/``.
    filter_query : str | None
        Optional ``DataFrame.query`` expression applied to the combined frame.
    random_sample : int | None
        If given, subsample this many units (fixed seed for reproducibility).

    Returns
    -------
    tuple
        ``(df, event_times_all, event_codes)`` where ``event_times_all`` maps
        dataset basename -> event-time array and ``event_codes`` maps
        event-name -> integer code.

    Raises
    ------
    ValueError
        If no dataset could be loaded at all.
    """
    if dataset_names is None:  # avoid mutable default argument
        dataset_names = []
    dfs = []
    event_times_all = {}
    for dataset_name in dataset_names:
        try:
            PATH = os.path.join('../../rawdata/', dataset_name)
            print(f'>>> Process data from {PATH}')
            spike_clusters = np.load(os.path.join(PATH, 'spike_clusters.npy'))
            spike_times = np.load(os.path.join(PATH, 'spike_times.npy'))
            region_map = np.loadtxt(os.path.join(PATH, 'region_map.tsv'), delimiter='\t', dtype=str)
            cluster_info = pd.read_csv(os.path.join(PATH, 'cluster_info.tsv'), sep='\t')
            channel_map = np.loadtxt(os.path.join(PATH, 'electrode_map.tsv'), delimiter='\t', dtype=int)
            event_times = np.loadtxt(os.path.join(PATH, 'event_times.tsv'), delimiter='\t', dtype=float)
            event_times = event_times.T  # Transpose to have time in first column
            event_times[0] = event_times[0] / 30E3  # samples -> seconds (assumes 30 kHz clock — TODO confirm)
            # Align region_map with channel_map: one region label per unit,
            # looked up via the unit's channel in the electrode map.
            cluster_info['region'] = [
                region_map[np.where(channel_map == ch)][0] for ch in cluster_info['ch']
            ]
            # Unify the id column name to 'cluster_id' regardless of source naming.
            id_key = [col for col in cluster_info.columns if 'id' in col]
            if len(id_key) != 1:
                # Explicit raise instead of assert: assert is stripped under -O.
                raise ValueError(
                    f"No unique id found in {os.path.join(PATH, 'cluster_info.tsv')}. Instead: {id_key}"
                )
            if id_key[0] != 'cluster_id':
                cluster_info['cluster_id'] = cluster_info[id_key[0]]
                cluster_info.drop(labels=[id_key[0]], inplace=True, axis=1)
            # Attach each unit's spike-time array, then convert samples -> seconds.
            cluster_info['spike_times'] = [
                spike_times[np.where(spike_clusters == cid)].squeeze()
                for cid in cluster_info['cluster_id']  # 'cid' avoids shadowing builtin 'id'
            ]
            cluster_info['spike_times'] = cluster_info['spike_times'] / 30E3
            cluster_info['dataset'] = os.path.basename(dataset_name)
            dfs.append(cluster_info)
            event_times_all[os.path.basename(dataset_name)] = event_times
        except Exception as e:
            # Best-effort by design: report the broken dataset and keep going.
            print(f"Error processing dataset {dataset_name}: {e}")
            continue
    if not dfs:
        # pd.concat([]) would raise an opaque ValueError; fail with context instead.
        raise ValueError("No datasets could be loaded; check dataset_names and the raw-data paths.")
    df = pd.concat(dfs, ignore_index=True)
    # Give each unit a globally unique ID: '<dataset>_<zero-padded cluster id>'.
    df['cluster_id'] = df['dataset'] + '_' + df['cluster_id'].apply(lambda x: f"{x:04d}")
    # Apply filters
    if filter_query is not None:
        df = df.query(filter_query)
    if random_sample is not None:
        df = df.sample(n=random_sample, random_state=42)  # fixed seed for reproducibility
    df.reset_index(drop=True, inplace=True)
    # Load event codes (first 8 rows: name -> integer code).
    event_codes = np.loadtxt(os.path.join('../../rawdata/2023', 'event_codes.tsv'), delimiter='\t', dtype=str, skiprows=1,)
    event_codes = {str(e[2]): int(e[0]) for e in event_codes[:8]}
    return df, event_times_all, event_codes
def load_data_2023(dt=DT, tau=TAU, query=QUERY):
    """Load and preprocess the 2023 dataset, caching the result on disk.

    On first call the raw data are loaded via :func:`load_all_raw_data`,
    spike densities are computed, and the result is pickled under
    ``derivatives/``; later calls with the same (dt, tau, query) hit the cache.

    Parameters
    ----------
    dt : float
        Time bin size [s] for spike density computation.
    tau : float
        Time constant [s] for spike density smoothing.
    query : str
        Filter query for the dataframe (part of the cache key).

    Returns
    -------
    tuple
        ``(df, event_times, event_codes)``.
    """
    RAWDATA_DIR = '../../rawdata/2023'
    DERIVATIVES_DIR = os.path.join(RAWDATA_DIR, 'derivatives')
    # Ensure the cache directory exists; the original os.listdir() crashed otherwise.
    os.makedirs(DERIVATIVES_DIR, exist_ok=True)
    # Cache key encodes every parameter that affects the preprocessed frame.
    cache_path = os.path.join(DERIVATIVES_DIR, f'2023_preprocessed_df_dt-{dt}_tau-{tau}_query-{query}.pkl')
    if os.path.exists(cache_path):
        print(f">>> Loading preprocessed data from {DERIVATIVES_DIR}")
        df = pd.read_pickle(cache_path)
        # Event times are parameter-independent, so they use a single shared pickle.
        with open(os.path.join(DERIVATIVES_DIR, '2023_preprocessed_event_times.pkl'), 'rb') as f:
            event_times = pickle.load(f)
        event_codes = np.loadtxt(os.path.join(RAWDATA_DIR, 'event_codes.tsv'), delimiter='\t', dtype=str, skiprows=1,)
        event_codes = {str(e[2]): int(e[0]) for e in event_codes[:8]}
    else:
        print(f">>> Processing raw data from {RAWDATA_DIR}")
        # Exclude 'derivatives' before joining: comparing the joined string to
        # '2023/derivatives' (as before) breaks on Windows path separators.
        paths = [
            os.path.join('2023', d)
            for d in os.listdir(RAWDATA_DIR)
            if d != 'derivatives' and os.path.isdir(os.path.join(RAWDATA_DIR, d))
        ]
        # Sort by the numeric suffix after '-' (e.g. 'sub-07'); unnumbered dirs first.
        paths.sort(key=lambda x: int(x.split('-')[1]) if '-' in x else 0)
        df, event_times, event_codes = load_all_raw_data(paths, filter_query=query)
        df['spike_density'] = df.apply(pu.compute_spike_density, axis=1, bin_size=dt, tau=tau)
        print(f">>> Saving preprocessed data to {DERIVATIVES_DIR}")
        df.to_pickle(cache_path)
        with open(os.path.join(DERIVATIVES_DIR, '2023_preprocessed_event_times.pkl'), 'wb') as f:
            pickle.dump(event_times, f)
    return df, event_times, event_codes