Skip to content

Commit 832f56b

Browse files
committed
feat: added filter by connection names to operations
1 parent 7d2c718 commit 832f56b

File tree

2 files changed

+222
-44
lines changed

2 files changed

+222
-44
lines changed

datashield/api.py

Lines changed: 86 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -129,23 +129,28 @@ def open(self, restore: str = None, failSafe: bool = False) -> None:
129129
for name in self.errors:
130130
logging.error(f"Connection to {name} has failed")
131131

132-
def close(self, save: str = None) -> None:
132+
def close(self, save: str = None, conn_names: list[str] = None) -> None:
133133
"""
134134
Close connections with remote servers.
135135
136-
:param cons: The list of connections to close.
137136
:param save: The name of the workspace to save before closing the connections.
137+
:param conn_names: The optional list of connection names to close. If not defined, all opened connections are closed.
138138
"""
139139
self.errors = {}
140-
for conn in self.conns:
140+
selected_conns = self._get_selected_connections(conn_names)
141+
selected_names = {conn.get_name() for conn in selected_conns}
142+
for conn in selected_conns:
141143
try:
142144
if save:
143145
conn.save_workspace(f"{conn.get_name()}:{save}")
144146
conn.disconnect()
145147
except DSError:
146148
# silently fail
147149
pass
148-
self.conns = None
150+
if conn_names is None:
151+
self.conns = None
152+
else:
153+
self.conns = [conn for conn in self.conns if conn.get_name() not in selected_names]
149154

150155
def has_connections(self) -> bool:
151156
"""
@@ -161,10 +166,7 @@ def get_connection_names(self) -> list[str]:
161166
162167
:return: The list of opened connection names
163168
"""
164-
if self.conns:
165-
return [conn.get_name() for conn in self.conns]
166-
else:
167-
return []
169+
return [conn.get_name() for conn in self.conns]
168170

169171
def has_errors(self) -> bool:
170172
"""
@@ -186,27 +188,29 @@ def get_errors(self) -> dict:
186188
# Environment
187189
#
188190

189-
def tables(self) -> dict:
191+
def tables(self, conn_names: list[str] = None) -> dict:
190192
"""
191193
List available table names from the data repository.
192194
195+
:param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
193196
:return: The available table names from the data repository, per remote server name
194197
"""
195198
rval = {}
196-
for conn in self.conns:
199+
for conn in self._get_selected_connections(conn_names):
197200
rval[conn.get_name()] = conn.list_tables()
198201
return rval
199202

200-
def variables(self, table: str = None, tables: dict = None) -> dict:
203+
def variables(self, table: str = None, tables: dict = None, conn_names: list[str] = None) -> dict:
201204
"""
202205
List available variables from the data repository, for a given table.
203206
204207
:param table: The default name of the table to list variables for
205208
:param tables: The name of the table to list variables for, per server name. If not defined, 'table' is used.
209+
:param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
206210
:return: The available variables from the data repository, for a given table, per remote server name
207211
"""
208212
rval = {}
209-
for conn in self.conns:
213+
for conn in self._get_selected_connections(conn_names):
210214
name = table
211215
if tables and conn.get_name() in tables:
212216
name = tables[conn.get_name()]
@@ -216,120 +220,130 @@ def variables(self, table: str = None, tables: dict = None) -> dict:
216220
rval[conn.get_name()] = None
217221
return rval
218222

219-
def taxonomies(self) -> dict:
223+
def taxonomies(self, conn_names: list[str] = None) -> dict:
220224
"""
221225
List available taxonomies from the data repository. A taxonomy is a hierarchical structure of vocabulary
222226
terms that can be used to annotate variables in the data repository.
223227
Depending on the data repository's capabilities, taxonomies can be used to perform structured
224228
queries when searching for variables.
225229
230+
:param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
226231
:return: The available taxonomies from the data repository, per remote server name
227232
"""
228233
rval = {}
229-
for conn in self.conns:
234+
for conn in self._get_selected_connections(conn_names):
230235
rval[conn.get_name()] = conn.list_taxonomies()
231236
return rval
232237

233-
def search_variables(self, query: str) -> dict:
238+
def search_variables(self, query: str, conn_names: list[str] = None) -> dict:
234239
"""
235240
Search for variable names matching a given query across all tables in the data repository.
236241
237242
:param query: The query to search for in variable names, e.g., a full-text search and/or structured
238243
query (based on taxonomy terms), depending on the data repository's capabilities
244+
:param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
239245
:return: The matching variable names from the data repository, per remote server name
240246
"""
241247
rval = {}
242-
for conn in self.conns:
248+
for conn in self._get_selected_connections(conn_names):
243249
rval[conn.get_name()] = conn.search_variables(query)
244250
return rval
245251

246-
def resources(self) -> dict:
252+
def resources(self, conn_names: list[str] = None) -> dict:
247253
"""
248254
List available resource names from the data repository.
249255
256+
:param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
250257
:return: The available resource names from the data repository, per remote server name
251258
"""
252259
rval = {}
253-
for conn in self.conns:
260+
for conn in self._get_selected_connections(conn_names):
254261
rval[conn.get_name()] = conn.list_resources()
255262
return rval
256263

257-
def profiles(self) -> dict:
264+
def profiles(self, conn_names: list[str] = None) -> dict:
258265
"""
259266
List available DataSHIELD profile names in the data repository.
260267
268+
:param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
261269
:return: The available DataSHIELD profile names in the data repository, per remote server name
262270
"""
263271
rval = {}
264-
for conn in self.conns:
272+
for conn in self._get_selected_connections(conn_names):
265273
rval[conn.get_name()] = conn.list_profiles()
266274
return rval
267275

268-
def packages(self) -> dict:
276+
def packages(self, conn_names: list[str] = None) -> dict:
269277
"""
270278
Get the list of DataSHIELD packages with their version, that have been configured on the remote data repository.
271279
280+
:param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
272281
:return: The list of DataSHIELD packages with their version, that have been configured on the remote data repository, per remote server name
273282
"""
274283
rval = {}
275-
for conn in self.conns:
284+
for conn in self._get_selected_connections(conn_names):
276285
rval[conn.get_name()] = conn.list_packages()
277286
return rval
278287

279-
def methods(self, type: str = "aggregate") -> dict:
288+
def methods(self, type: str = "aggregate", conn_names: list[str] = None) -> dict:
280289
"""
281290
Get the list of DataSHIELD methods that have been configured on the remote data repository.
282291
283292
:param type: The type of method, either "aggregate" (default) or "assign"
293+
:param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
284294
:return: The list of DataSHIELD methods that have been configured on the remote data repository, per remote server name
285295
"""
286296
rval = {}
287-
for conn in self.conns:
297+
for conn in self._get_selected_connections(conn_names):
288298
rval[conn.get_name()] = conn.list_methods(type)
289299
return rval
290300

291301
#
292302
# Workspaces
293303
#
294304

295-
def workspaces(self) -> dict:
305+
def workspaces(self, conn_names: list[str] = None) -> dict:
296306
"""
297307
Get the list of DataSHIELD workspaces, that have been saved on the remote data repository.
298308
309+
:param conn_names: The optional list of connection names to query. If not defined, all opened connections are queried.
299310
:return: The list of DataSHIELD workspaces, that have been saved on the remote data repository, per remote server name
300311
"""
301312
rval = {}
302-
for conn in self.conns:
313+
for conn in self._get_selected_connections(conn_names):
303314
rval[conn.get_name()] = conn.list_workspaces()
304315
return rval
305316

306-
def workspace_save(self, name: str) -> None:
317+
def workspace_save(self, name: str, conn_names: list[str] = None) -> None:
307318
"""
308319
Save the DataSHIELD R session in a workspace on the remote data repository.
309320
310321
:param name: The name of the workspace
322+
:param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
311323
"""
312-
for conn in self.conns:
324+
for conn in self._get_selected_connections(conn_names):
313325
conn.save_workspace(f"{conn.get_name()}:{name}")
314326

315-
def workspace_restore(self, name: str) -> None:
327+
def workspace_restore(self, name: str, conn_names: list[str] = None) -> None:
316328
"""
317329
Restore a saved DataSHIELD R session from the remote data repository. When restoring a workspace,
318330
any existing symbol or file with same name will be overridden.
319331
320332
:param name: The name of the workspace
333+
:param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
321334
"""
322-
for conn in self.conns:
335+
for conn in self._get_selected_connections(conn_names):
323336
conn.restore_workspace(f"{conn.get_name()}:{name}")
324337

325-
def workspace_rm(self, name: str) -> None:
338+
def workspace_rm(self, name: str, conn_names: list[str] = None) -> None:
326339
"""
327340
Remove a DataSHIELD workspace from the remote data repository. Ignored if no
328341
such workspace exists.
329342
330343
:param name: The name of the workspace
344+
:param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
331345
"""
332-
for conn in self.conns:
346+
for conn in self._get_selected_connections(conn_names):
333347
conn.rm_workspace(f"{conn.get_name()}:{name}")
334348

335349
#
@@ -358,6 +372,9 @@ def sessions(self) -> dict:
358372
"""
359373
rval = {}
360374
self._init_errors()
375+
if len(self.conns) == 0:
376+
return rval
377+
361378
started_conns = []
362379
excluded_conns = []
363380

@@ -405,11 +422,11 @@ def sessions(self) -> dict:
405422
if len(excluded_conns) > 0:
406423
logging.error(f"Some sessions have been excluded due to errors: {', '.join(excluded_conns)}")
407424
self.conns = [conn for conn in self.conns if conn.get_name() not in excluded_conns]
408-
if len(self.conns) == 0:
425+
if len(self.conns) == len(excluded_conns):
409426
raise DSError("No sessions could be started successfully.")
410427
return rval
411428

412-
def ls(self) -> dict:
429+
def ls(self, conn_names: list[str] = None) -> dict:
413430
"""
414431
After assignments have been performed, list the symbols that live in the DataSHIELD R session on the server side.
415432
@@ -418,7 +435,7 @@ def ls(self) -> dict:
418435
# ensure sessions are started and available
419436
self.sessions()
420437
rval = {}
421-
for conn in self.conns:
438+
for conn in self._get_selected_connections(conn_names):
422439
try:
423440
rval[conn.get_name()] = conn.list_symbols()
424441
except Exception as e:
@@ -427,15 +444,16 @@ def ls(self) -> dict:
427444
self._check_errors()
428445
return rval
429446

430-
def rm(self, symbol: str) -> None:
447+
def rm(self, symbol: str, conn_names: list[str] = None) -> None:
431448
"""
432449
Remove a symbol from remote servers.
433450
434451
:param symbol: The name of the symbol to remove
452+
:param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
435453
"""
436454
# ensure sessions are started and available
437455
self.sessions()
438-
for conn in self.conns:
456+
for conn in self._get_selected_connections(conn_names):
439457
try:
440458
conn.rm_symbol(symbol)
441459
except Exception as e:
@@ -452,6 +470,7 @@ def assign_table(
452470
identifiers: str = None,
453471
id_name: str = None,
454472
asynchronous: bool = True,
473+
conn_names: list[str] = None,
455474
) -> None:
456475
"""
457476
Assign a data table from the data repository to a symbol in the DataSHIELD R session.
@@ -460,11 +479,12 @@ def assign_table(
460479
:param table: The default name of the table to assign
461480
:param tables: The name of the table to assign, per server name. If not defined, 'table' is used.
462481
:param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server)
482+
:param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
463483
"""
464484
# ensure sessions are started and available
465485
self.sessions()
466486
cmd = {}
467-
for conn in self.conns:
487+
for conn in self._get_selected_connections(conn_names):
468488
name = table
469489
if tables and conn.get_name() in tables:
470490
name = tables[conn.get_name()]
@@ -478,7 +498,12 @@ def assign_table(
478498
self._check_errors()
479499

480500
def assign_resource(
481-
self, symbol: str, resource: str = None, resources: dict = None, asynchronous: bool = True
501+
self,
502+
symbol: str,
503+
resource: str = None,
504+
resources: dict = None,
505+
asynchronous: bool = True,
506+
conn_names: list[str] = None,
482507
) -> None:
483508
"""
484509
Assign a resource from the data repository to a symbol in the DataSHIELD R session.
@@ -487,11 +512,12 @@ def assign_resource(
487512
:param resource: The default name of the resource to assign
488513
:param resources: The name of the resource to assign, per server name. If not defined, 'resource' is used.
489514
:param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server)
515+
:param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
490516
"""
491517
# ensure sessions are started and available
492518
self.sessions()
493519
cmd = {}
494-
for conn in self.conns:
520+
for conn in self._get_selected_connections(conn_names):
495521
name = resource
496522
if resources and conn.get_name() in resources:
497523
name = resources[conn.get_name()]
@@ -504,18 +530,19 @@ def assign_resource(
504530
self._do_wait(cmd)
505531
self._check_errors()
506532

507-
def assign_expr(self, symbol: str, expr: str, asynchronous: bool = True) -> None:
533+
def assign_expr(self, symbol: str, expr: str, asynchronous: bool = True, conn_names: list[str] = None) -> None:
508534
"""
509535
Assign the result of the evaluation of an expression to a symbol in the DataSHIELD R session.
510536
511537
:param symbol: The name of the destination symbol
512538
:param expr: The R expression to evaluate and which result will be assigned
513539
:param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server)
540+
:param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
514541
"""
515542
# ensure sessions are started and available
516543
self.sessions()
517544
cmd = {}
518-
for conn in self.conns:
545+
for conn in self._get_selected_connections(conn_names):
519546
try:
520547
res = conn.assign_expr(symbol, expr, asynchronous)
521548
cmd[conn.get_name()] = res
@@ -524,20 +551,21 @@ def assign_expr(self, symbol: str, expr: str, asynchronous: bool = True) -> None
524551
self._do_wait(cmd)
525552
self._check_errors()
526553

527-
def aggregate(self, expr: str, asynchronous: bool = True) -> dict:
554+
def aggregate(self, expr: str, asynchronous: bool = True, conn_names: list[str] = None) -> dict:
528555
"""
529556
Aggregate some data from the DataSHIELD R session using a valid R expression. The
530557
aggregation expression must satisfy the data repository's DataSHIELD configuration.
531558
532559
:param expr: The R expression to evaluate and which result will be returned
533560
:param asynchronous: Whether the operation is asynchronous (if supported by the DataSHIELD server)
561+
:param conn_names: The optional list of connection names to apply this operation to. If not defined, all opened connections are used.
534562
:return: The result of the aggregation expression evaluation, per remote server name
535563
"""
536564
# ensure sessions are started and available
537565
self.sessions()
538566
cmd = {}
539567
rval = {}
540-
for conn in self.conns:
568+
for conn in self._get_selected_connections(conn_names):
541569
try:
542570
res = conn.aggregate(expr, asynchronous)
543571
cmd[conn.get_name()] = res
@@ -573,6 +601,20 @@ def _do_wait(self, cmd: dict) -> dict:
573601
time.sleep(0.1)
574602
return rval
575603

604+
def _get_selected_connections(self, conn_names: list[str] = None) -> list[DSConnection]:
605+
"""
606+
Get the list of opened connections, optionally filtered by connection names.
607+
608+
:param conn_names: The optional list of connection names to select.
609+
:return: The list of selected opened connections
610+
"""
611+
if not self.conns:
612+
return []
613+
if conn_names is None:
614+
return self.conns
615+
selected_names = set(conn_names)
616+
return [conn for conn in self.conns if conn.get_name() in selected_names]
617+
576618
def _init_errors(self) -> None:
577619
"""
578620
Prepare for storing errors.

0 commit comments

Comments
 (0)