Skip to content

Commit f740854

Browse files
authored
Merge pull request #638 from NeuralEnsemble/sonata
Merge SONATA branch
2 parents ed1cddd + 6d1e062 commit f740854

3 files changed

Lines changed: 179 additions & 65 deletions

File tree

pyNN/common/populations.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1470,3 +1470,23 @@ def describe(self, template='assembly_default.txt', engine='default'):
14701470
context = {"label": self.label,
14711471
"populations": [p.describe(template=None) for p in self.populations]}
14721472
return descriptions.render(engine, template, context)
1473+
1474+
def get_annotations(self, annotation_keys, simplify=True):
    """
    Get the values of the given annotations for each population in the Assembly.
    """
    # A single key may be passed as a bare string; normalise to a tuple.
    if isinstance(annotation_keys, basestring):
        annotation_keys = (annotation_keys,)
    annotations = defaultdict(list)
    for key in annotation_keys:
        values = [pop.annotations[key] for pop in self.populations]
        annotations[key] = values
        # NOTE(review): mirrors the original per-population loop — only the
        # type of the *last* population's annotation decides whether to stack.
        if values and isinstance(values[-1], numpy.ndarray):
            annotations[key] = numpy.hstack(annotations[key])
        if simplify:
            annotations[key] = simplify_parameter_array(numpy.array(annotations[key]))
    return annotations

pyNN/network.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
44
"""
55

6+
import sys
7+
import inspect
68
from itertools import chain
79
try:
810
basestring
@@ -22,6 +24,23 @@ def __init__(self, *components):
2224
self.views = set([])
2325
self.assemblies = set([])
2426
self.projections = set([])
27+
self.add(*components)
28+
29+
@property
def sim(self):
    """Figure out which PyNN backend module this Network is using."""
    # We assume a single backend. This could be wrong if multiple
    # simulators are used at once.
    any_population = list(self.populations)[0]
    backend_module = inspect.getmodule(any_population.__class__)
    parent_name = ".".join(backend_module.__name__.split(".")[:-1])
    return sys.modules[parent_name]
36+
37+
def count_neurons(self):
    """Return the total number of neurons across all populations in the network."""
    # `chain()` with a single iterable was a no-op wrapper; iterate directly.
    return sum(population.size for population in self.populations)
39+
40+
def count_connections(self):
    """Return the total number of connections across all projections in the network."""
    # `chain()` with a single iterable was a no-op wrapper; iterate directly.
    return sum(projection.size() for projection in self.projections)
42+
43+
def add(self, *components):
2544
for component in components:
2645
if isinstance(component, Population):
2746
self.populations.add(component)
@@ -37,18 +56,24 @@ def __init__(self, *components):
3756
else:
3857
raise TypeError()
3958

40-
def count_neurons(self):
41-
return sum(population.size for population in chain(self.populations))
42-
43-
def count_connections(self):
44-
return sum(projection.size() for projection in chain(self.projections))
45-
4659
def get_component(self, label):
    """Return the network component (population, view, assembly or
    projection) with the given label, or None if there is no match."""
    everything = chain(self.populations, self.views, self.assemblies, self.projections)
    return next((obj for obj in everything if obj.label == label), None)
5164

65+
def filter(self, cell_types=None):
    """Return an Assembly of all components that have a cell type in the list"""
    if cell_types is None:
        raise NotImplementedError()
    if cell_types == "all":
        # or could use len(receptor_types) > 0
        selected = [pop for pop in self.populations if pop.celltype.injectable]
    else:
        selected = [pop for pop in self.populations
                    if pop.celltype.__class__ in cell_types]
    return self.sim.Assembly(*selected)
76+
5277
def record(self, variables, to_file=None, sampling_interval=None, include_spike_source=True):
5378
for obj in chain(self.populations, self.assemblies):
5479
if include_spike_source or obj.injectable: # spike sources are not injectable

pyNN/serialization/sonata.py

Lines changed: 128 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ def write(self, blocks):
105105
Write a list of Blocks to SONATA HDF5 files.
106106
107107
"""
108+
if not os.path.isdir(self.base_dir):
109+
os.makedirs(self.base_dir)
108110
# Write spikes
109111
spike_file_path = join(self.base_dir, self.spike_file)
110112
spikes_file = h5py.File(spike_file_path, 'w')
@@ -131,36 +133,38 @@ def write(self, blocks):
131133
file_path = join(self.base_dir, file_name)
132134

133135
signal_file = h5py.File(file_path, 'w')
134-
population_name = self.node_sets[report_metadata["cells"]]["population"]
135-
node_ids = self.node_sets[report_metadata["cells"]]["node_id"]
136+
targets = self.node_sets[report_metadata["cells"]]
136137
for block in blocks:
137-
if block.name == population_name:
138-
if len(block.segments) > 1:
139-
raise NotImplementedError()
140-
signal = block.segments[0].filter(name=report_metadata["variable_name"])
141-
if len(signal) != 1:
142-
raise NotImplementedError()
143-
144-
report_group = signal_file.create_group("report")
145-
population_group = report_group.create_group(population_name)
146-
dataset = population_group.create_dataset("data", data=signal[0].magnitude)
147-
dataset.attrs["units"] = signal[0].units.dimensionality.string
148-
dataset.attrs["variable_name"] = report_metadata["variable_name"]
149-
n = dataset.shape[1]
150-
mapping_group = population_group.create_group("mapping")
151-
mapping_group.create_dataset("node_ids", data=node_ids)
152-
# "gids" not in the spec, but expected by some bmtk utils
153-
mapping_group.create_dataset("gids", data=node_ids)
154-
#mapping_group.create_dataset("index_pointers", data=np.zeros((n,)))
155-
mapping_group.create_dataset("index_pointer", data=np.arange(0, n+1)) # ??spec unclear
156-
mapping_group.create_dataset("element_ids", data=np.zeros((n,)))
157-
mapping_group.create_dataset("element_pos", data=np.zeros((n,)))
158-
time_ds = mapping_group.create_dataset("time",
159-
data=(float(signal[0].t_start),
160-
float(signal[0].t_stop),
161-
float(signal[0].sampling_period)))
162-
time_ds.attrs["units"] = "ms"
163-
logger.info("Wrote block {} to {}".format(block.name, file_path))
138+
for (assembly, mask) in targets:
139+
if block.name == assembly.label:
140+
if len(block.segments) > 1:
141+
raise NotImplementedError()
142+
signal = block.segments[0].filter(name=report_metadata["variable_name"])
143+
if len(signal) != 1:
144+
raise NotImplementedError()
145+
146+
node_ids = np.arange(assembly.size)[mask]
147+
148+
report_group = signal_file.create_group("report")
149+
population_group = report_group.create_group(assembly.label)
150+
dataset = population_group.create_dataset("data", data=signal[0].magnitude)
151+
dataset.attrs["units"] = signal[0].units.dimensionality.string
152+
dataset.attrs["variable_name"] = report_metadata["variable_name"]
153+
n = dataset.shape[1]
154+
mapping_group = population_group.create_group("mapping")
155+
mapping_group.create_dataset("node_ids", data=node_ids)
156+
# "gids" not in the spec, but expected by some bmtk utils
157+
mapping_group.create_dataset("gids", data=node_ids)
158+
#mapping_group.create_dataset("index_pointers", data=np.zeros((n,)))
159+
mapping_group.create_dataset("index_pointer", data=np.arange(0, n+1)) # ??spec unclear
160+
mapping_group.create_dataset("element_ids", data=np.zeros((n,)))
161+
mapping_group.create_dataset("element_pos", data=np.zeros((n,)))
162+
time_ds = mapping_group.create_dataset("time",
163+
data=(float(signal[0].t_start.rescale('ms')),
164+
float(signal[0].t_stop.rescale('ms')),
165+
float(signal[0].sampling_period.rescale('ms'))))
166+
time_ds.attrs["units"] = "ms"
167+
logger.info("Wrote block {} to {}".format(block.name, file_path))
164168
signal_file.close()
165169

166170

@@ -232,6 +236,7 @@ def condense(value, types_array):
232236
from "/edges/<population_name>/edge_type_id" that applies to this group.
233237
Needed to construct parameter arrays.
234238
"""
239+
# todo: use lazyarray
235240
if isinstance(value, np.ndarray):
236241
return value
237242
elif isinstance(value, dict):
@@ -240,7 +245,12 @@ def condense(value, types_array):
240245
if np.all(value_array == value_array[0]):
241246
return value_array[0]
242247
else:
243-
new_value = np.ones_like(types_array) * np.nan
248+
if np.issubdtype(value_array.dtype, np.number):
249+
new_value = np.ones_like(types_array) * np.nan
250+
elif np.issubdtype(value_array.dtype, np.str_):
251+
new_value = np.array(["UNDEFINED"] * types_array.size)
252+
else:
253+
raise TypeError("Cannot handle annotations that are neither numbers or strings")
244254
for node_type_id, val in value.items():
245255
new_value[types_array == node_type_id] = val
246256
return new_value
@@ -584,10 +594,10 @@ def import_from_sonata(config_file, sim):
584594
net = Network()
585595
for node_population in sonata_node_populations:
586596
assembly = node_population.to_assembly(sim)
587-
net.assemblies.add(assembly)
597+
net.add(assembly)
588598
for edge_population in sonata_edge_populations:
589599
projections = edge_population.to_projections(net, sim)
590-
net.projections.update(projections)
600+
net.add(*projections)
591601

592602
return net
593603

@@ -777,7 +787,7 @@ def to_population(self, sim):
777787
if name in cell_type_cls.default_parameters:
778788
parameters[name] = condense(value, self.node_types_array)
779789
else:
780-
annotations[name] = value
790+
annotations[name] = condense(value, self.node_types_array)
781791
# todo: handle spatial structure - nodes_file["nodes"][np_label][ng_label]['x'], etc.
782792

783793
# temporary hack to work around problem with 300 Intfire cell example
@@ -1072,28 +1082,21 @@ def setup(self, sim):
10721082
self.sim = sim
10731083
sim.setup(timestep=self.run_config["dt"])
10741084

1075-
def _get_target(self, config, node_sets, net):
1085+
def _get_target(self, config, net):
10761086
if "node_set" in config: # input config
1077-
target = node_sets[config["node_set"]]
1078-
elif "cells" in config: # recording config
1087+
targets = self.node_set_map[config["node_set"]]
1088+
elif "cells" in config: # recording config
10791089
# inconsistency in SONATA spec? Why not call this "node_set" also?
1080-
target = node_sets[config["cells"]]
1081-
if "model_type" in target:
1082-
raise NotImplementedError()
1083-
if "location" in target:
1084-
raise NotImplementedError()
1085-
if "gids" in target:
1086-
raise NotImplementedError()
1087-
if "population" in target:
1088-
assembly = net.get_component(target["population"])
1089-
if "node_id" in target:
1090-
indices = target["node_id"]
1091-
assembly = assembly[indices]
1092-
return assembly
1090+
targets = self.node_set_map[config["cells"]]
1091+
return targets
10931092

1094-
def _set_input_spikes(self, input_config, node_sets, net):
1093+
def _set_input_spikes(self, input_config, net):
10951094
# determine which assembly the spikes are for
1096-
assembly = self._get_target(input_config, node_sets, net)
1095+
targets = self._get_target(input_config, net)
1096+
if len(targets) != 1:
1097+
raise NotImplementedError()
1098+
base_assembly, mask = targets[0]
1099+
assembly = base_assembly[mask]
10971100
assert isinstance(assembly, self.sim.Assembly)
10981101

10991102
# load spike data from file
@@ -1111,22 +1114,88 @@ def _set_input_spikes(self, input_config, node_sets, net):
11111114
if len(spiketrains) != assembly.size:
11121115
raise NotImplementedError()
11131116
# todo: map cell ids in spikes file to ids/index in the population
1114-
#logger.info("SETTING SPIKETIMES")
1115-
#logger.info(spiketrains)
11161117
assembly.set(spike_times=[Sequence(st.times.rescale('ms').magnitude) for st in spiketrains])
11171118

1119+
def _set_input_currents(self, input_config, net):
1120+
# determine which assembly the currents are for
1121+
if "input_file" in input_config:
1122+
raise NotImplementedError("Current clamp from source file not yet supported.")
1123+
targets = self._get_target(input_config, net)
1124+
if len(targets) != 1:
1125+
raise NotImplementedError()
1126+
base_assembly, mask = targets[0]
1127+
assembly = base_assembly[mask]
1128+
assert isinstance(assembly, self.sim.Assembly)
1129+
amplitude = input_config["amp"] # nA
1130+
if self.target_simulator == "NEST":
1131+
amplitude = input_config["amp"]/1000.0 # pA
1132+
1133+
current_source = self.sim.DCSource(amplitude=amplitude,
1134+
start=input_config["delay"],
1135+
stop=input_config["delay"] + input_config["duration"])
1136+
assembly.inject(current_source)
1137+
1138+
def _calculate_node_set_map(self, net):
1139+
# for each "node set" in the config, determine which populations
1140+
# and node_ids it corresponds to
1141+
self.node_set_map = {}
1142+
1143+
# first handle implicit node sets - i.e. each node population is an implicit node set
1144+
for assembly in net.assemblies:
1145+
self.node_set_map[assembly.label] = [(assembly, slice(None))]
1146+
1147+
# now handle explictly-declared node sets
1148+
# todo: handle compound node sets
1149+
for node_set_name, node_set_definition in self.node_sets.items():
1150+
if isinstance(node_set_definition, dict): # basic node set
1151+
filters = node_set_definition
1152+
if "population" in filters:
1153+
assemblies = [net.get_component(filters["population"])]
1154+
else:
1155+
assemblies = list(net.assemblies)
1156+
1157+
self.node_set_map[node_set_name] = []
1158+
for assembly in assemblies:
1159+
mask = True
1160+
for attr_name, attr_value in filters.items():
1161+
print(attr_name, attr_value, "____")
1162+
if attr_name == "population":
1163+
continue
1164+
elif attr_name == "node_id":
1165+
# convert integer mask to boolean mask
1166+
node_mask = np.zeros(assembly.size, dtype=bool)
1167+
node_mask[attr_value] = True
1168+
mask = np.logical_and(mask, node_mask)
1169+
else:
1170+
values = assembly.get_annotations(attr_name)[attr_name]
1171+
mask = np.logical_and(mask, values == attr_value)
1172+
if isinstance(mask, (bool, np.bool_)) and mask == True:
1173+
mask = slice(None)
1174+
self.node_set_map[node_set_name].append((assembly, mask))
1175+
elif isinstance(node_set_definition, list): # compound node set
1176+
raise NotImplementedError("Compound node sets not yet supported")
1177+
else:
1178+
raise TypeError("Expecting node set definition to be a list or dict")
1179+
11181180
def execute(self, net):
1181+
self._calculate_node_set_map(net)
1182+
11191183
# create/configure inputs
11201184
for input_name, input_config in self.inputs.items():
1121-
if input_config["input_type"] != "spikes":
1122-
raise NotImplementedError()
1123-
self._set_input_spikes(input_config, self.node_sets, net)
1185+
if input_config["input_type"] == "spikes":
1186+
self._set_input_spikes(input_config, net)
1187+
elif input_config["input_type"] == "current_clamp":
1188+
self._set_input_currents(input_config, net)
1189+
else:
1190+
raise NotImplementedError("Only 'spikes' and 'current_clamp' supported")
11241191

11251192
# configure recording
11261193
net.record('spikes', include_spike_source=False) # SONATA requires that we record spikes from all non-virtual nodes
11271194
for report_name, report_config in self.reports.items():
1128-
assembly = self._get_target(report_config, self.node_sets, net)
1129-
assembly.record(report_config["variable_name"])
1195+
targets = self._get_target(report_config, net)
1196+
for (base_assembly, mask) in targets:
1197+
assembly = base_assembly[mask]
1198+
assembly.record(report_config["variable_name"])
11301199

11311200
# run simulation
11321201
self.sim.run(self.run_config["tstop"])
@@ -1141,7 +1210,7 @@ def execute(self, net):
11411210
spikes_file=self.output.get("spikes_file", "spikes.h5"),
11421211
spikes_sort_order=self.output["spikes_sort_order"],
11431212
report_config=self.reports,
1144-
node_sets=self.node_sets)
1213+
node_sets=self.node_set_map)
11451214
# todo: handle reports
11461215
net.write_data(io)
11471216

0 commit comments

Comments
 (0)