-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathrest_api_recipe_session.py
More file actions
200 lines (188 loc) · 9.4 KB
/
rest_api_recipe_session.py
File metadata and controls
200 lines (188 loc) · 9.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
from dataikuapi.utils import DataikuException
from rest_api_client import RestAPIClient
from safe_logger import SafeLogger
from dku_utils import parse_keys_for_json, get_value_from_path, decode_csv_data, de_NaN, decode_bytes
from dku_constants import DKUConstants
import copy
import json
import requests
import collections
# Module-level logger shared by the whole recipe session.
# NOTE(review): SafeLogger presumably redacts values whose keys appear in
# DKUConstants.FORBIDDEN_KEYS (credentials etc.) — confirm in safe_logger.
logger = SafeLogger("api-connect plugin", forbidden_keys=DKUConstants.FORBIDDEN_KEYS)
class RestApiRecipeSession:
    """One recipe run: calls a REST endpoint once per input-dataframe row and
    accumulates the (possibly paginated) responses as flat output row dicts.

    NOTE(review): RestAPIClient, the dku_utils helpers and DKUConstants are
    project-local; their exact semantics are inferred from this file only.
    """

    def __init__(self, custom_key_values, credential_parameters, secure_credentials, endpoint_parameters, extraction_key, parameter_columns, parameter_renamings,
                 display_metadata=False,
                 maximum_number_rows=-1, behaviour_when_error=None):
        """
        :param custom_key_values: extra key/values forwarded to RestAPIClient
        :param credential_parameters: credential settings (redacted in logs via logger.filter_secrets)
        :param secure_credentials: secure credential settings, forwarded to RestAPIClient
        :param endpoint_parameters: base endpoint description; overridden per row by parameter columns
        :param extraction_key: dotted path to the list of records inside the JSON
            response; falsy means "use the whole response"
        :param parameter_columns: input columns whose values become endpoint parameters
        :param parameter_renamings: optional {column name: parameter name} mapping
        :param display_metadata: if True, every output row carries the client's metadata
        :param maximum_number_rows: stop paginating a row once this many records were
            produced for it (<= 0 means no limit)
        :param behaviour_when_error: "raise" | "ignore" | "add-error-column" |
            "keep-error-column"; None defaults to "add-error-column"
        """
        self.custom_key_values = custom_key_values
        self.credential_parameters = credential_parameters
        self.secure_credentials = secure_credentials
        self.endpoint_parameters = endpoint_parameters
        self.extraction_key = extraction_key
        self.client = None
        # Per-row {parameter name: value} snapshot, refreshed in process_dataframe.
        self.initial_parameter_columns = None
        self.column_to_parameter_dict = self.get_column_to_parameter_dict(parameter_columns, parameter_renamings)
        self.display_metadata = display_metadata
        self.maximum_number_rows = maximum_number_rows
        self.is_row_limit = (self.maximum_number_rows > 0)
        self.behaviour_when_error = behaviour_when_error or "add-error-column"
        self.can_raise = self.behaviour_when_error == "raise"

    @staticmethod
    def get_column_to_parameter_dict(parameter_columns, parameter_renamings):
        """Map each parameter column to its endpoint parameter name.

        A column listed in parameter_renamings uses the renamed target; any other
        column maps to itself.
        """
        column_to_parameter_dict = {}
        for parameter_column in parameter_columns:
            if parameter_column in parameter_renamings:
                column_to_parameter_dict[parameter_column] = parameter_renamings[parameter_column]
            else:
                column_to_parameter_dict[parameter_column] = parameter_column
        return column_to_parameter_dict

    def process_dataframe(self, input_parameters_dataframe, is_raw_output):
        """Call the endpoint once per dataframe row and return all output rows.

        :param input_parameters_dataframe: pandas DataFrame; the columns named in
            column_to_parameter_dict override endpoint parameters row by row
        :param is_raw_output: if True, responses are dumped whole into the
            API_RESPONSE_KEY column instead of being flattened key by key
        :return: list of dicts, one per extracted record
        """
        results = []
        time_last_request = None
        # One HTTP session reused across every row and page of this run.
        session = requests.Session()
        for index, input_parameters_row in input_parameters_dataframe.iterrows():
            rows_count = 0
            self.initial_parameter_columns = {}
            for column_name in self.column_to_parameter_dict:
                parameter_name = self.column_to_parameter_dict[column_name]
                self.initial_parameter_columns.update(
                    {
                        parameter_name: de_NaN(input_parameters_row.get(column_name))
                    }
                )
            updated_endpoint_parameters = copy.deepcopy(self.endpoint_parameters)
            updated_endpoint_parameters.update(self.initial_parameter_columns)
            logger.info("Processing row #{}, creating client with credential={}, updated_endpoint={}, custom_key_values={}".format(
                index + 1,
                logger.filter_secrets(self.credential_parameters),
                updated_endpoint_parameters,
                self.custom_key_values
            ))
            self.client = RestAPIClient(
                self.credential_parameters,
                self.secure_credentials,
                updated_endpoint_parameters,
                custom_key_values=self.custom_key_values,
                session=session,
                behaviour_when_error=self.behaviour_when_error
            )
            # Carry the previous client's last-request timestamp over so that
            # request pacing spans the whole dataframe, not just one row.
            self.client.time_last_request = time_last_request
            while self.client.has_more_data():
                page_results = self.retrieve_next_page(is_raw_output)
                results.extend(page_results)
                rows_count += len(page_results)
                if self.is_row_limit and rows_count >= self.maximum_number_rows:
                    break
            time_last_request = self.client.time_last_request
        return results

    def retrieve_next_page(self, is_raw_output):
        """Fetch one page from the current client and convert it to output row dicts.

        Handles four payload shapes: extracted-by-key records, a plain dict, a
        list of dicts, and non-JSON data (CSV, falling back to raw text).
        """
        page_rows = []
        logger.info("retrieve_next_page: Calling next page")
        # Note: "can_raise_exeption" (sic) is the client's keyword — do not rename here.
        json_response = self.client.paginated_api_call(can_raise_exeption=self.can_raise)
        # Seed every row with an (empty) error column when the user asked to keep it.
        default_dict = {
            DKUConstants.REPONSE_ERROR_KEY: ""
        } if self.behaviour_when_error == "keep-error-column" else {}
        if isinstance(json_response, dict) and DKUConstants.REPONSE_ERROR_KEY in default_dict:
            default_dict[DKUConstants.REPONSE_ERROR_KEY] = json_response.get(DKUConstants.REPONSE_ERROR_KEY, None)
        metadata = self.client.get_metadata() if self.display_metadata else default_dict
        is_api_returning_dict = True
        if self.extraction_key:
            data_rows = get_value_from_path(json_response, self.extraction_key.split("."), can_raise=False)
            if data_rows is None:
                if self.behaviour_when_error == "ignore":
                    return []
                error_message = "Extraction key '{}' was not found in the incoming data".format(self.extraction_key)
                if self.can_raise:
                    raise DataikuException(error_message)
                elif DKUConstants.REPONSE_ERROR_KEY in metadata:
                    return [metadata]
                else:
                    return self.format_page_rows([{DKUConstants.REPONSE_ERROR_KEY: error_message}], is_raw_output, metadata)
            page_rows.extend(self.format_page_rows(data_rows, is_raw_output, metadata))
        else:
            # Todo: check api_response key is free and add something otherwise
            base_row = copy.deepcopy(metadata)
            if is_raw_output:
                assert_json(json_response)
                if is_error_message(json_response):
                    base_row.update(parse_keys_for_json(json_response))
                else:
                    base_row.update({
                        DKUConstants.API_RESPONSE_KEY: json.dumps(json_response)
                    })
            else:
                if isinstance(json_response, dict):
                    base_row.update(parse_keys_for_json(json_response))
                elif isinstance(json_response, list):
                    is_api_returning_dict = False
                    for row in json_response:
                        base_row = copy.deepcopy(metadata)
                        base_row.update(parse_keys_for_json(row))
                        base_row.update(self.initial_parameter_columns)
                        page_rows.append(base_row)
                else:
                    # Not JSON: try CSV first, else dump the raw text into one cell.
                    decoded_csv_data = decode_csv_data(json_response)
                    is_api_returning_dict = False
                    if not decoded_csv_data and json_response:
                        logger.warning("Data is not in CSV format. Dumping it in text mode.")
                        decoded_csv_data = [
                            {
                                DKUConstants.API_RESPONSE_KEY: "{}".format(
                                    decode_bytes(json_response)
                                )
                            }
                        ]
                    for row in decoded_csv_data:
                        base_row = copy.deepcopy(metadata)
                        base_row.update(parse_keys_for_json(row))
                        base_row.update(self.initial_parameter_columns)
                        page_rows.append(base_row)
            if is_api_returning_dict:
                base_row.update(self.initial_parameter_columns)
                page_rows.append(base_row)
        return page_rows

    def format_page_rows(self, data_rows, is_raw_output, metadata=None):
        """Turn extracted records into output row dicts.

        Each row starts from the per-input-row parameter values, then metadata,
        then the record itself (flattened, or JSON-dumped when is_raw_output).

        :param data_rows: list of records, a single dict/OrderedDict, or raw
            str/bytes (decoded as CSV first)
        :param is_raw_output: dump each record as a JSON string under
            API_RESPONSE_KEY instead of flattening it
        :param metadata: extra columns merged into every row (default: none)
        """
        page_rows = []
        metadata = metadata or {}
        if type(data_rows) in [str, bytes]:
            data_rows = decode_csv_data(data_rows)
        if type(data_rows) in [list]:
            for data_row in data_rows:
                base_row = copy.deepcopy(self.initial_parameter_columns)
                base_row.update(metadata)
                if is_raw_output:
                    if is_error_message(data_row):
                        base_row.update(parse_keys_for_json(data_row))
                    else:
                        base_row.update({
                            DKUConstants.API_RESPONSE_KEY: json.dumps(data_row)
                        })
                else:
                    base_row.update(parse_keys_for_json(data_row))
                page_rows.append(base_row)
        if type(data_rows) in [dict, collections.OrderedDict]:
            base_row = copy.deepcopy(self.initial_parameter_columns)
            base_row.update(metadata)
            if is_raw_output:
                # Bug fix: this branch previously tested and parsed `data_row`
                # (the list branch's loop variable) instead of `data_rows`,
                # raising NameError for a dict payload in raw-output mode.
                if is_error_message(data_rows):
                    base_row.update({
                        DKUConstants.API_RESPONSE_KEY: None
                    })
                    base_row.update(parse_keys_for_json(data_rows))
                else:
                    base_row.update({
                        DKUConstants.API_RESPONSE_KEY: json.dumps(data_rows)
                    })
            else:
                base_row.update(parse_keys_for_json(data_rows))
            page_rows.append(base_row)
        return page_rows
def is_error_message(jsons_response):
    """Return True when the payload consists solely of the plugin's error entry.

    A payload qualifies only if it is a dict or list containing exactly one
    element, and that element (key, for dicts) is DKUConstants.REPONSE_ERROR_KEY.
    """
    # NOTE(review): exact type check kept on purpose — dict subclasses such as
    # collections.OrderedDict fall through to False, as in the original code.
    if type(jsons_response) not in (dict, list):
        return False
    return DKUConstants.REPONSE_ERROR_KEY in jsons_response and len(jsons_response) == 1
def assert_json(variable_to_check):
    """Ensure the value is JSON-structured (a dict or a list).

    Returns None on success; raises Exception with user-facing recipe guidance
    for any other type.
    """
    if not isinstance(variable_to_check, (dict, list)):
        raise Exception("Returned data is not JSON format. Try again with 'Raw JSON output' un-checked.")