@@ -105,6 +105,8 @@ def write(self, blocks):
105105 Write a list of Blocks to SONATA HDF5 files.
106106
107107 """
108+ if not os .path .isdir (self .base_dir ):
109+ os .makedirs (self .base_dir )
108110 # Write spikes
109111 spike_file_path = join (self .base_dir , self .spike_file )
110112 spikes_file = h5py .File (spike_file_path , 'w' )
@@ -131,36 +133,38 @@ def write(self, blocks):
131133 file_path = join (self .base_dir , file_name )
132134
133135 signal_file = h5py .File (file_path , 'w' )
134- population_name = self .node_sets [report_metadata ["cells" ]]["population" ]
135- node_ids = self .node_sets [report_metadata ["cells" ]]["node_id" ]
136+ targets = self .node_sets [report_metadata ["cells" ]]
136137 for block in blocks :
137- if block .name == population_name :
138- if len (block .segments ) > 1 :
139- raise NotImplementedError ()
140- signal = block .segments [0 ].filter (name = report_metadata ["variable_name" ])
141- if len (signal ) != 1 :
142- raise NotImplementedError ()
143-
144- report_group = signal_file .create_group ("report" )
145- population_group = report_group .create_group (population_name )
146- dataset = population_group .create_dataset ("data" , data = signal [0 ].magnitude )
147- dataset .attrs ["units" ] = signal [0 ].units .dimensionality .string
148- dataset .attrs ["variable_name" ] = report_metadata ["variable_name" ]
149- n = dataset .shape [1 ]
150- mapping_group = population_group .create_group ("mapping" )
151- mapping_group .create_dataset ("node_ids" , data = node_ids )
152- # "gids" not in the spec, but expected by some bmtk utils
153- mapping_group .create_dataset ("gids" , data = node_ids )
154- #mapping_group.create_dataset("index_pointers", data=np.zeros((n,)))
155- mapping_group .create_dataset ("index_pointer" , data = np .arange (0 , n + 1 )) # ??spec unclear
156- mapping_group .create_dataset ("element_ids" , data = np .zeros ((n ,)))
157- mapping_group .create_dataset ("element_pos" , data = np .zeros ((n ,)))
158- time_ds = mapping_group .create_dataset ("time" ,
159- data = (float (signal [0 ].t_start ),
160- float (signal [0 ].t_stop ),
161- float (signal [0 ].sampling_period )))
162- time_ds .attrs ["units" ] = "ms"
163- logger .info ("Wrote block {} to {}" .format (block .name , file_path ))
138+ for (assembly , mask ) in targets :
139+ if block .name == assembly .label :
140+ if len (block .segments ) > 1 :
141+ raise NotImplementedError ()
142+ signal = block .segments [0 ].filter (name = report_metadata ["variable_name" ])
143+ if len (signal ) != 1 :
144+ raise NotImplementedError ()
145+
146+ node_ids = np .arange (assembly .size )[mask ]
147+
148+ report_group = signal_file .create_group ("report" )
149+ population_group = report_group .create_group (assembly .label )
150+ dataset = population_group .create_dataset ("data" , data = signal [0 ].magnitude )
151+ dataset .attrs ["units" ] = signal [0 ].units .dimensionality .string
152+ dataset .attrs ["variable_name" ] = report_metadata ["variable_name" ]
153+ n = dataset .shape [1 ]
154+ mapping_group = population_group .create_group ("mapping" )
155+ mapping_group .create_dataset ("node_ids" , data = node_ids )
156+ # "gids" not in the spec, but expected by some bmtk utils
157+ mapping_group .create_dataset ("gids" , data = node_ids )
158+ #mapping_group.create_dataset("index_pointers", data=np.zeros((n,)))
159+ mapping_group .create_dataset ("index_pointer" , data = np .arange (0 , n + 1 )) # ??spec unclear
160+ mapping_group .create_dataset ("element_ids" , data = np .zeros ((n ,)))
161+ mapping_group .create_dataset ("element_pos" , data = np .zeros ((n ,)))
162+ time_ds = mapping_group .create_dataset ("time" ,
163+ data = (float (signal [0 ].t_start .rescale ('ms' )),
164+ float (signal [0 ].t_stop .rescale ('ms' )),
165+ float (signal [0 ].sampling_period .rescale ('ms' ))))
166+ time_ds .attrs ["units" ] = "ms"
167+ logger .info ("Wrote block {} to {}" .format (block .name , file_path ))
164168 signal_file .close ()
165169
166170
@@ -232,6 +236,7 @@ def condense(value, types_array):
232236 from "/edges/<population_name>/edge_type_id" that applies to this group.
233237 Needed to construct parameter arrays.
234238 """
239+ # todo: use lazyarray
235240 if isinstance (value , np .ndarray ):
236241 return value
237242 elif isinstance (value , dict ):
@@ -240,7 +245,12 @@ def condense(value, types_array):
240245 if np .all (value_array == value_array [0 ]):
241246 return value_array [0 ]
242247 else :
243- new_value = np .ones_like (types_array ) * np .nan
248+ if np .issubdtype (value_array .dtype , np .number ):
249+ new_value = np .ones_like (types_array ) * np .nan
250+ elif np .issubdtype (value_array .dtype , np .str_ ):
251+ new_value = np .array (["UNDEFINED" ] * types_array .size )
252+ else :
253+ raise TypeError ("Cannot handle annotations that are neither numbers or strings" )
244254 for node_type_id , val in value .items ():
245255 new_value [types_array == node_type_id ] = val
246256 return new_value
@@ -584,10 +594,10 @@ def import_from_sonata(config_file, sim):
584594 net = Network ()
585595 for node_population in sonata_node_populations :
586596 assembly = node_population .to_assembly (sim )
587- net .assemblies . add (assembly )
597+ net .add (assembly )
588598 for edge_population in sonata_edge_populations :
589599 projections = edge_population .to_projections (net , sim )
590- net .projections . update ( projections )
600+ net .add ( * projections )
591601
592602 return net
593603
@@ -777,7 +787,7 @@ def to_population(self, sim):
777787 if name in cell_type_cls .default_parameters :
778788 parameters [name ] = condense (value , self .node_types_array )
779789 else :
780- annotations [name ] = value
790+ annotations [name ] = condense ( value , self . node_types_array )
781791 # todo: handle spatial structure - nodes_file["nodes"][np_label][ng_label]['x'], etc.
782792
783793 # temporary hack to work around problem with 300 Intfire cell example
@@ -1072,28 +1082,21 @@ def setup(self, sim):
10721082 self .sim = sim
10731083 sim .setup (timestep = self .run_config ["dt" ])
10741084
1075- def _get_target (self , config , node_sets , net ):
1085+ def _get_target (self , config , net ):
10761086 if "node_set" in config : # input config
1077- target = node_sets [config ["node_set" ]]
1078- elif "cells" in config : # recording config
1087+ targets = self . node_set_map [config ["node_set" ]]
1088+ elif "cells" in config : # recording config
10791089 # inconsistency in SONATA spec? Why not call this "node_set" also?
1080- target = node_sets [config ["cells" ]]
1081- if "model_type" in target :
1082- raise NotImplementedError ()
1083- if "location" in target :
1084- raise NotImplementedError ()
1085- if "gids" in target :
1086- raise NotImplementedError ()
1087- if "population" in target :
1088- assembly = net .get_component (target ["population" ])
1089- if "node_id" in target :
1090- indices = target ["node_id" ]
1091- assembly = assembly [indices ]
1092- return assembly
1090+ targets = self .node_set_map [config ["cells" ]]
1091+ return targets
10931092
1094- def _set_input_spikes (self , input_config , node_sets , net ):
1093+ def _set_input_spikes (self , input_config , net ):
10951094 # determine which assembly the spikes are for
1096- assembly = self ._get_target (input_config , node_sets , net )
1095+ targets = self ._get_target (input_config , net )
1096+ if len (targets ) != 1 :
1097+ raise NotImplementedError ()
1098+ base_assembly , mask = targets [0 ]
1099+ assembly = base_assembly [mask ]
10971100 assert isinstance (assembly , self .sim .Assembly )
10981101
10991102 # load spike data from file
@@ -1111,22 +1114,88 @@ def _set_input_spikes(self, input_config, node_sets, net):
11111114 if len (spiketrains ) != assembly .size :
11121115 raise NotImplementedError ()
11131116 # todo: map cell ids in spikes file to ids/index in the population
1114- #logger.info("SETTING SPIKETIMES")
1115- #logger.info(spiketrains)
11161117 assembly .set (spike_times = [Sequence (st .times .rescale ('ms' ).magnitude ) for st in spiketrains ])
11171118
1119+ def _set_input_currents (self , input_config , net ):
1120+ # determine which assembly the currents are for
1121+ if "input_file" in input_config :
1122+ raise NotImplementedError ("Current clamp from source file not yet supported." )
1123+ targets = self ._get_target (input_config , net )
1124+ if len (targets ) != 1 :
1125+ raise NotImplementedError ()
1126+ base_assembly , mask = targets [0 ]
1127+ assembly = base_assembly [mask ]
1128+ assert isinstance (assembly , self .sim .Assembly )
1129+ amplitude = input_config ["amp" ] # nA
1130+ if self .target_simulator == "NEST" :
1131+ amplitude = input_config ["amp" ]/ 1000.0 # pA
1132+
1133+ current_source = self .sim .DCSource (amplitude = amplitude ,
1134+ start = input_config ["delay" ],
1135+ stop = input_config ["delay" ] + input_config ["duration" ])
1136+ assembly .inject (current_source )
1137+
1138+ def _calculate_node_set_map (self , net ):
1139+ # for each "node set" in the config, determine which populations
1140+ # and node_ids it corresponds to
1141+ self .node_set_map = {}
1142+
1143+ # first handle implicit node sets - i.e. each node population is an implicit node set
1144+ for assembly in net .assemblies :
1145+ self .node_set_map [assembly .label ] = [(assembly , slice (None ))]
1146+
1147+ # now handle explictly-declared node sets
1148+ # todo: handle compound node sets
1149+ for node_set_name , node_set_definition in self .node_sets .items ():
1150+ if isinstance (node_set_definition , dict ): # basic node set
1151+ filters = node_set_definition
1152+ if "population" in filters :
1153+ assemblies = [net .get_component (filters ["population" ])]
1154+ else :
1155+ assemblies = list (net .assemblies )
1156+
1157+ self .node_set_map [node_set_name ] = []
1158+ for assembly in assemblies :
1159+ mask = True
1160+ for attr_name , attr_value in filters .items ():
1161+ print (attr_name , attr_value , "____" )
1162+ if attr_name == "population" :
1163+ continue
1164+ elif attr_name == "node_id" :
1165+ # convert integer mask to boolean mask
1166+ node_mask = np .zeros (assembly .size , dtype = bool )
1167+ node_mask [attr_value ] = True
1168+ mask = np .logical_and (mask , node_mask )
1169+ else :
1170+ values = assembly .get_annotations (attr_name )[attr_name ]
1171+ mask = np .logical_and (mask , values == attr_value )
1172+ if isinstance (mask , (bool , np .bool_ )) and mask == True :
1173+ mask = slice (None )
1174+ self .node_set_map [node_set_name ].append ((assembly , mask ))
1175+ elif isinstance (node_set_definition , list ): # compound node set
1176+ raise NotImplementedError ("Compound node sets not yet supported" )
1177+ else :
1178+ raise TypeError ("Expecting node set definition to be a list or dict" )
1179+
11181180 def execute (self , net ):
1181+ self ._calculate_node_set_map (net )
1182+
11191183 # create/configure inputs
11201184 for input_name , input_config in self .inputs .items ():
1121- if input_config ["input_type" ] != "spikes" :
1122- raise NotImplementedError ()
1123- self ._set_input_spikes (input_config , self .node_sets , net )
1185+ if input_config ["input_type" ] == "spikes" :
1186+ self ._set_input_spikes (input_config , net )
1187+ elif input_config ["input_type" ] == "current_clamp" :
1188+ self ._set_input_currents (input_config , net )
1189+ else :
1190+ raise NotImplementedError ("Only 'spikes' and 'current_clamp' supported" )
11241191
11251192 # configure recording
11261193 net .record ('spikes' , include_spike_source = False ) # SONATA requires that we record spikes from all non-virtual nodes
11271194 for report_name , report_config in self .reports .items ():
1128- assembly = self ._get_target (report_config , self .node_sets , net )
1129- assembly .record (report_config ["variable_name" ])
1195+ targets = self ._get_target (report_config , net )
1196+ for (base_assembly , mask ) in targets :
1197+ assembly = base_assembly [mask ]
1198+ assembly .record (report_config ["variable_name" ])
11301199
11311200 # run simulation
11321201 self .sim .run (self .run_config ["tstop" ])
@@ -1141,7 +1210,7 @@ def execute(self, net):
11411210 spikes_file = self .output .get ("spikes_file" , "spikes.h5" ),
11421211 spikes_sort_order = self .output ["spikes_sort_order" ],
11431212 report_config = self .reports ,
1144- node_sets = self .node_sets )
1213+ node_sets = self .node_set_map )
11451214 # todo: handle reports
11461215 net .write_data (io )
11471216
0 commit comments