Skip to content

Commit e784f9b

Browse files
committed
Fix access to result_set and QueryPolyResult construction
1 parent eb87e7e commit e784f9b

3 files changed

Lines changed: 26 additions & 20 deletions

File tree

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[metadata]
22
name = ipython-polypheny
3-
version = 0.1.2
3+
version = 0.1.3
44
description = Access Polypheny via IPython magics
55
long_description = file: README.md
66
long_description_content_type = text/markdown

src/poly/poly_magic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def expand_variables(self, template):
145145

146146
def separate_args(line, cell, split_str=":"):
147147
"""
148-
Finds the first occurence of an element of termination_strings in line. The line is split after this element and
148+
Finds the first occurrence of an element of termination_strings in line. The line is split after this element and
149149
the two parts are returned.
150150
"""
151151
if cell is None:

src/poly/poly_result.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
def build_result(raw_result: Union[str, dict, List[dict]]):
77
result_set = get_result_dict(raw_result)
88

9-
if 'error' in result_set:
9+
if result_set.get('error') is not None:
1010
return ErrorPolyResult(result_set)
1111
if 'header' not in result_set:
1212
return InfoPolyResult(result_set)
@@ -35,33 +35,39 @@ def get_type_from_string(type_string):
3535
return float
3636
if cleaned in ['ARRAY', 'JSON', 'DOCUMENT', 'DOCUMENT NOT NULL']:
3737
return json.loads
38-
if cleaned.startswith('PATH') or cleaned.startswith('NODE'):
38+
if cleaned.startswith('PATH') or cleaned.startswith('NODE') or cleaned.startswith('EDGE'):
3939
return json.loads
4040
return None
4141

4242

43-
def cast_data(raw_data, header):
44-
data = [row[:] for row in raw_data] # Create shallow copy
45-
data_types = [get_type_from_string(col['dataType']) for col in header]
46-
47-
for col_idx, col_type in enumerate(data_types):
48-
if col_type is None:
49-
continue
50-
for row_idx in range(len(data)):
51-
try:
52-
data[row_idx][col_idx] = col_type(data[row_idx][col_idx])
53-
except (TypeError, json.JSONDecodeError):
54-
pass
43+
def cast_data(raw_data, header, data_model):
44+
if data_model in ['RELATIONAL', 'GRAPH']:
45+
data = [row[:] for row in raw_data] # Create shallow copy
46+
data_types = [get_type_from_string(col['dataType']) for col in header]
47+
48+
for col_idx, col_type in enumerate(data_types):
49+
if col_type is None:
50+
continue
51+
for row_idx in range(len(data)):
52+
try:
53+
data[row_idx][col_idx] = col_type(data[row_idx][col_idx])
54+
except (TypeError, json.JSONDecodeError):
55+
pass
56+
elif data_model == 'DOCUMENT':
57+
# transform documents to rows with 1 column
58+
data = [[json.loads(doc[0] if isinstance(doc, list) else doc)] for doc in raw_data]
59+
else:
60+
raise ValueError(f'Unsupported dataModel: {data_model}')
5561
return data
5662

5763

5864
class QueryPolyResult(list): # Data is stored in a nested list. For simplicity, we do not use Numpy, but Python lists
5965
def __init__(self, result_set):
6066
self.result_set = result_set
61-
self.type = result_set['namespaceType']
67+
self.type = result_set['dataModel']
6268
self._header = result_set['header']
6369
self.keys = [col['name'] for col in self._header]
64-
self._data = cast_data(result_set['data'], self._header)
70+
self._data = cast_data(result_set['data'], self._header, self.type)
6571
self._pretty = PrettyTable(self.keys)
6672
self._pretty.add_rows(self._data)
6773
self._pretty.set_style(PLAIN_COLUMNS)
@@ -87,7 +93,7 @@ def __init__(self, result_set):
8793

8894
def __repr__(self):
8995
rows = ['Successfully executed:',
90-
'Query:'.ljust(30) + str(self.result_set['generatedQuery'])
96+
'Query:'.ljust(30) + str(self.result_set['query'])
9197
]
9298
return "\n".join(rows)
9399

@@ -98,7 +104,7 @@ def __init__(self, result_set):
98104

99105
def __repr__(self):
100106
rows = ['ERROR:',
101-
'Query:'.ljust(30) + str(self.result_set['generatedQuery']),
107+
'Query:'.ljust(30) + str(self.result_set['query']),
102108
'Message:'.ljust(30) + str(self.result_set['error'])
103109
]
104110
return "\n".join(rows)

0 commit comments

Comments
 (0)