Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 19 additions & 14 deletions src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1349,25 +1349,30 @@ def alphabetize_dict_recursive(obj):
def get_dataset_documents(uuid):
validate_token_if_auth_header_exists(request)
token = get_internal_token()
excluded_fields = None
include_fields = None
if bool(request.args):
excluded = request.args.get('exclude')
if excluded:
excluded_fields = [
included = request.args.get('include')
if included:
include_fields = [
f.strip().strip("'").strip('"')
for f in excluded.split(',')
for f in included.split(',')
if f.strip()
]
# Validation step to ensure fields are real property names
valid_fields = set(schema_manager.get_persistent_fields())
invalid = [f for f in include_fields if f not in valid_fields]
if invalid:
return bad_request_error(f"Invalid include fields: {invalid}")
else:
return bad_request_error("Missing required parameter: 'include'. Must include a list of properties to be returned.")
else:
return bad_request_error("Missing required parameter: 'include'. Must include a list of properties to be returned.")

# This is a validation step. Because we're allowing excluded fields to be passed from search-api,
# we want to minimally at least make sure these are real property names before using them for
# querying neo4j.
valid_fields = set(schema_manager.get_persistent_fields())
invalid = [f for f in excluded_fields if f not in valid_fields]
if invalid:
return bad_request_error(f"Invalid excluded fields: {invalid}")

entity_record = app_neo4j_queries.get_dataset_documents_raw(neo4j_driver_instance, uuid, excluded_fields=excluded_fields)
entity_record = app_neo4j_queries.get_dataset_documents_raw(
neo4j_driver_instance,
uuid,
included_fields=include_fields
)
if entity_record is None:
return not_found_error(f"Entity {uuid} not found")

Expand Down
18 changes: 6 additions & 12 deletions src/app_neo4j_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,10 +256,7 @@ def get_source_samples(neo4j_driver, uuid):
are found, or None if the input UUID does not correspond to a supported
entity type.
"""
def get_dataset_documents_raw(neo4j_driver, uuid, excluded_fields=None):
if excluded_fields is None:
excluded_fields = []

def get_dataset_documents_raw(neo4j_driver, uuid, included_fields):
with neo4j_driver.session() as session:
entity_record = session.run("""
MATCH (e:Entity {uuid: $uuid})
Expand All @@ -279,22 +276,19 @@ def get_dataset_documents_raw(neo4j_driver, uuid, excluded_fields=None):
root_label = 'Upload'
else:
return None

projection = "d { .* }"

if excluded_fields:
null_projection = ", ".join(f"{field}: NULL" for field in excluded_fields)
projection = f"d {{ .*, {null_projection} }}"

record = session.run("""
MATCH (root:%s {uuid: $uuid})<-[:%s]-(d:Dataset)
RETURN apoc.map.fromPairs(COLLECT([d.uuid, %s])) AS result
""" % (root_label, relationship, projection), uuid=uuid).single()
WITH apoc.coll.toSet(COLLECT(d)) AS datasets
RETURN [d IN datasets | d { %s }] AS result
""" % (root_label, relationship, ', '.join(f'.{f}' for f in included_fields)),
uuid=uuid).single()

if not record or not record["result"]:
return {}

return {uuid: dict(props) for uuid, props in record["result"].items()}
return {d['uuid']: dict(d) for d in record["result"]}



Expand Down
Loading