Commit 9014f17
NiFi scripts: lazy cerner blob decompression possible fix. (#58)
* NiFi scripts: lazy Cerner blob decompression possible fix.
* NiFi processors: added more sanity checks to the Cerner processor.
* NiFi scripts: more Cerner blob processing attempted fixes.
* NiFi scripts: Cerner blob decompression updates.
* NiFi: Cerner processor, added extra debug attributes.
* NiFi: Cerner blob processor update.
* NiFi: base NiFi processor update.
* NiFi: Cerner processor corrections.
* NiFi: base NiFi processor updates; logging fixes for Cerner decompression.
* NiFi: Cerner blob decompress, revert last commit.
* NiFi scripts: error handling fix.
1 parent 75c8073 commit 9014f17

4 files changed

Lines changed: 183 additions & 67 deletions

File tree

nifi/user_python_extensions/record_decompress_cerner_blob.py

Lines changed: 119 additions & 29 deletions
@@ -33,6 +33,8 @@ class Java:
 
     class ProcessorDetails:
         version = '0.0.1'
+        description = "Decompresses Cerner LZW compressed blobs from a JSON input stream"
+        tags = ["cerner", "oracle", "blob"]
 
     def __init__(self, jvm: JVMView):
         super().__init__(jvm)
@@ -110,6 +112,7 @@ def transform(self, context: ProcessContext, flowFile: JavaObject) -> FlowFileTransformResult:
         """
 
         output_contents: list = []
+        attributes: dict = {k: str(v) for k, v in flowFile.getAttributes().items()}
 
         try:
             self.process_context = context
@@ -118,7 +121,7 @@ def transform(self, context: ProcessContext, flowFile: JavaObject) -> FlowFileTransformResult:
             # read avro record
             input_raw_bytes: bytes | bytearray = flowFile.getContentsAsBytes()
 
-            records = []
+            records: list | dict = []
 
             try:
                 records = json.loads(input_raw_bytes.decode())
@@ -131,35 +134,70 @@ def transform(self, context: ProcessContext, flowFile: JavaObject) -> FlowFileTransformResult:
                 try:
                     records = json.loads(input_raw_bytes.decode("windows-1252"))
                 except json.JSONDecodeError as e:
-                    self.logger.error(f"Error decoding JSON: {str(e)} \n with windows-1252")
-                    raise
+                    return self.build_failure_result(
+                        flowFile,
+                        ValueError(f"Error decoding JSON: {str(e)} \n with windows-1252"),
+                        attributes=attributes,
+                        contents=input_raw_bytes,
+                    )
 
             if not isinstance(records, list):
                 records = [records]
 
             if not records:
-                raise ValueError("No records found in JSON input")
+                return self.build_failure_result(
+                    flowFile,
+                    ValueError("No records found in JSON input"),
+                    attributes=attributes,
+                    contents=input_raw_bytes,
+                )
+
+            # sanity check: blobs are from the same document_id
+            doc_ids: set = {str(r.get(self.document_id_field_name, "")) for r in records}
+            if len(doc_ids) > 1:
+                return self.build_failure_result(
+                    flowFile,
+                    ValueError(f"Multiple document IDs in one FlowFile: {list(doc_ids)}"),
+                    attributes=attributes,
+                    contents=input_raw_bytes,
+                )
 
-            concatenated_blob_sequence_order = {}
-            output_merged_record = {}
+            concatenated_blob_sequence_order: dict = {}
+            output_merged_record: dict = {}
 
-            have_any_sequence = any(self.blob_sequence_order_field_name in record for record in records)
-            have_any_no_sequence = any(self.blob_sequence_order_field_name not in record for record in records)
+            have_any_sequence: bool = any(self.blob_sequence_order_field_name in record for record in records)
+            have_any_no_sequence: bool = any(self.blob_sequence_order_field_name not in record for record in records)
 
             if have_any_sequence and have_any_no_sequence:
-                raise ValueError(
-                    f"Mixed records: some have '{self.blob_sequence_order_field_name}', some don't. "
-                    "Cannot safely reconstruct blob stream."
+                return self.build_failure_result(
+                    flowFile,
+                    ValueError(
+                        f"Mixed records: some have '{self.blob_sequence_order_field_name}', some don't. "
+                        "Cannot safely reconstruct blob stream."
+                    ),
+                    attributes=attributes,
+                    contents=input_raw_bytes,
                 )
 
             for record in records:
                 if self.binary_field_name not in record or record[self.binary_field_name] in (None, ""):
-                    raise ValueError(f"Missing '{self.binary_field_name}' in a record")
+                    return self.build_failure_result(
+                        flowFile,
+                        ValueError(f"Missing '{self.binary_field_name}' in a record"),
+                        attributes=attributes,
+                        contents=input_raw_bytes,
+                    )
 
                 if have_any_sequence:
                     seq = int(record[self.blob_sequence_order_field_name])
                     if seq in concatenated_blob_sequence_order:
-                        raise ValueError(f"Duplicate {self.blob_sequence_order_field_name}: {seq}")
+                        return self.build_failure_result(
+                            flowFile,
+                            ValueError(f"Duplicate {self.blob_sequence_order_field_name}: {seq}"),
+                            attributes=attributes,
+                            contents=input_raw_bytes,
+                        )
+
                     concatenated_blob_sequence_order[seq] = record[self.binary_field_name]
                 else:
                     # no sequence anywhere: preserve record order (0..n-1)
@@ -174,48 +212,100 @@ def transform(self, context: ProcessContext, flowFile: JavaObject) -> FlowFileTransformResult:
 
             full_compressed_blob = bytearray()
 
-            for k in sorted(concatenated_blob_sequence_order.keys()):
+            # double check to make sure there is no gap in the blob sequence, i.e missing blob.
+            order_of_blobs_keys = sorted(concatenated_blob_sequence_order.keys())
+            for i in range(1, len(order_of_blobs_keys)):
+                if order_of_blobs_keys[i] != order_of_blobs_keys[i-1] + 1:
+                    return self.build_failure_result(
+                        flowFile,
+                        ValueError(
+                            f"Sequence gap: missing {order_of_blobs_keys[i-1] + 1} "
+                            f"(have {order_of_blobs_keys[i-1]} then {order_of_blobs_keys[i]})"
+                        ),
+                        attributes=attributes,
+                        contents=input_raw_bytes,
+                    )
+
+            for k in order_of_blobs_keys:
                 v = concatenated_blob_sequence_order[k]
 
+                temporary_blob: bytes = b""
+
                 if self.binary_field_source_encoding == "base64":
                     if not isinstance(v, str):
-                        raise ValueError(f"Expected base64 string in {self.binary_field_name} for part {k}, got {type(v)}")
+                        return self.build_failure_result(
+                            flowFile,
+                            ValueError(
+                                f"Expected base64 string in {self.binary_field_name} for part {k}, got {type(v)}"
+                            ),
+                            attributes=attributes,
+                            contents=input_raw_bytes,
+                        )
                     try:
                         temporary_blob = base64.b64decode(v, validate=True)
                     except Exception as e:
-                        raise ValueError(f"Error decoding base64 blob part {k}: {e}")
+                        return self.build_failure_result(
+                            flowFile,
+                            ValueError(f"Error decoding base64 blob part {k}: {e}"),
+                            attributes=attributes,
+                            contents=input_raw_bytes,
+                        )
                 else:
                     # raw bytes path
                     if isinstance(v, (bytes, bytearray)):
                         temporary_blob = v
                     else:
-                        raise ValueError(f"Expected bytes in {self.binary_field_name} for part {k}, got {type(v)}")
-
+                        return self.build_failure_result(
+                            flowFile,
+                            ValueError(
+                                f"Expected bytes in {self.binary_field_name} for part {k}, got {type(v)}"
+                            ),
+                            attributes=attributes,
+                            contents=input_raw_bytes,
+                        )
+
                 full_compressed_blob.extend(temporary_blob)
 
+            # build / add new attributes to dict before doing anything else to have some trace.
+            attributes["document_id_field_name"] = str(self.document_id_field_name)
+            attributes["document_id"] = str(output_merged_record.get(self.document_id_field_name, ""))
+            attributes["binary_field"] = str(self.binary_field_name)
+            attributes["output_text_field_name"] = str(self.output_text_field_name)
+            attributes["mime.type"] = "application/json"
+            attributes["blob_parts"] = str(len(order_of_blobs_keys))
+            attributes["blob_seq_min"] = str(order_of_blobs_keys[0]) if order_of_blobs_keys else ""
+            attributes["blob_seq_max"] = str(order_of_blobs_keys[-1]) if order_of_blobs_keys else ""
+            attributes["compressed_len"] = str(len(full_compressed_blob))
+            attributes["compressed_head_hex"] = bytes(full_compressed_blob[:16]).hex()
+
             try:
                 decompress_blob = DecompressLzwCernerBlob()
                 decompress_blob.decompress(full_compressed_blob)
-                output_merged_record[self.binary_field_name] = decompress_blob.output_stream
+                output_merged_record[self.binary_field_name] = bytes(decompress_blob.output_stream)
             except Exception as exception:
-                self.logger.error(f"Error decompressing cerner blob: {str(exception)} \n")
-                raise exception
+                return self.build_failure_result(
+                    flowFile,
+                    exception=exception,
+                    attributes=attributes,
+                    include_flowfile_attributes=False,
+                    contents=input_raw_bytes
+                )
 
             if self.output_mode == "base64":
                 output_merged_record[self.binary_field_name] = \
                     base64.b64encode(output_merged_record[self.binary_field_name]).decode(self.output_charset)
 
             output_contents.append(output_merged_record)
 
-            attributes: dict = {k: str(v) for k, v in flowFile.getAttributes().items()}
-            attributes["document_id_field_name"] = str(self.document_id_field_name)
-            attributes["binary_field"] = str(self.binary_field_name)
-            attributes["output_text_field_name"] = str(self.output_text_field_name)
-            attributes["mime.type"] = "application/json"
-
-            return FlowFileTransformResult(relationship="success",
+            return FlowFileTransformResult(relationship=self.REL_SUCCESS,
                                            attributes=attributes,
                                            contents=json.dumps(output_contents).encode("utf-8"))
         except Exception as exception:
             self.logger.error("Exception during flowfile processing: " + traceback.format_exc())
-            raise exception
+            return self.build_failure_result(
                flowFile,
                exception,
                attributes=attributes,
                contents=locals().get("input_raw_bytes", flowFile.getContentsAsBytes()),
                include_flowfile_attributes=False
            )
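
Every error path above now returns through self.build_failure_result(...) instead of raising, so malformed input is routed to the failure relationship with its trace attributes attached rather than bubbling an exception up to the framework. The helper itself lives in the shared base processor, which this commit does not show. Below is a minimal sketch of what such a helper could look like, inferred only from the call sites in this diff; the method body, the "failure.reason" attribute name, and the defaults are assumptions, not the repo's actual code:

    # Hypothetical reconstruction of the base-class helper used above; the real
    # implementation lives in the shared base NiFi processor, not in this diff.
    from nifiapi.flowfiletransform import FlowFileTransformResult

    def build_failure_result(self, flowFile, exception,
                             attributes: dict | None = None,
                             contents: bytes | None = None,
                             include_flowfile_attributes: bool = True) -> FlowFileTransformResult:
        # start from the incoming FlowFile's attributes unless told not to
        attrs: dict = dict(flowFile.getAttributes()) if include_flowfile_attributes else {}
        if attributes:
            attrs.update({k: str(v) for k, v in attributes.items()})
        # surface the error text so downstream processors can route or alert on it
        attrs["failure.reason"] = str(exception)
        self.logger.error(f"Routing FlowFile to failure: {exception}")
        return FlowFileTransformResult(
            relationship="failure",
            attributes=attrs,
            # keep the original payload so the failed FlowFile can be replayed
            contents=contents if contents is not None else flowFile.getContentsAsBytes(),
        )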
Lines changed: 41 additions & 26 deletions
@@ -1,4 +1,3 @@
-
 class LzwItem:
     def __init__(self, _prefix: int = 0, _suffix: int = 0) -> None:
         self.prefix = _prefix
@@ -9,7 +8,7 @@ class DecompressLzwCernerBlob:
     def __init__(self) -> None:
         self.MAX_CODES: int = 8192
         self.tmp_decompression_buffer: list[int] = [0] * self.MAX_CODES
-        self.lzw_lookup_table: list[LzwItem] = [LzwItem()] * self.MAX_CODES
+        self.lzw_lookup_table: list[LzwItem] = [LzwItem() for _ in range(self.MAX_CODES)]
         self.tmp_buffer_index: int = 0
         self.current_byte_buffer_index: int = 0
 
@@ -21,19 +20,24 @@ def save_to_lookup_table(self, compressed_code: int):
         self.tmp_buffer_index = -1
         while compressed_code >= 258:
             self.tmp_buffer_index += 1
-            self.tmp_decompression_buffer[self.tmp_buffer_index] = \
-                self.lzw_lookup_table[compressed_code].suffix
+            self.tmp_decompression_buffer[self.tmp_buffer_index] = self.lzw_lookup_table[compressed_code].suffix
             compressed_code = self.lzw_lookup_table[compressed_code].prefix
 
         self.tmp_buffer_index += 1
         self.tmp_decompression_buffer[self.tmp_buffer_index] = compressed_code
 
-        for i in reversed(list(range(self.tmp_buffer_index + 1))):
-            self.output_stream.append(self.tmp_decompression_buffer[i])
+        for i in reversed(range(self.tmp_buffer_index + 1)):
+            v = self.tmp_decompression_buffer[i]
+            if not (0 <= v <= 255):
+                raise ValueError(f"Invalid output byte {v} (expected 0..255)")
+            self.output_stream.append(v)
 
-    def decompress(self, input_stream: bytearray = bytearray()):
+    def decompress(self, input_stream: bytearray):
+        if not input_stream:
+            raise ValueError("Empty input_stream")
 
         byte_buffer_index: int = 0
+        self.output_stream = bytearray()
 
         # used for bit shifts
         shift: int = 1
@@ -49,46 +53,56 @@ def decompress(self, input_stream: bytearray = bytearray()):
 
         while True:
             if current_shift >= 9:
-
                 current_shift -= 8
-
                 if first_code != 0:
                     byte_buffer_index += 1
+                    if byte_buffer_index >= len(input_stream):
+                        raise ValueError("Truncated input_stream")
+
                     middle_code = input_stream[byte_buffer_index]
 
-                    first_code = (first_code << current_shift +
-                                  8) | (middle_code << current_shift)
+                    first_code = (first_code << (current_shift + 8)) | (middle_code << current_shift)
 
                     byte_buffer_index += 1
+
+                    if byte_buffer_index >= len(input_stream):
+                        raise ValueError("Truncated input_stream")
+
                     middle_code = input_stream[byte_buffer_index]
 
                     tmp_code = middle_code >> (8 - current_shift)
                     lookup_index = first_code | tmp_code
-
                     skip_flag = True
                 else:
                     byte_buffer_index += 1
+                    if byte_buffer_index >= len(input_stream):
+                        raise ValueError("Truncated input_stream")
                     first_code = input_stream[byte_buffer_index]
                     byte_buffer_index += 1
+                    if byte_buffer_index >= len(input_stream):
+                        raise ValueError("Truncated input_stream")
                    middle_code = input_stream[byte_buffer_index]
             else:
                 byte_buffer_index += 1
+                if byte_buffer_index >= len(input_stream):
+                    raise ValueError("Truncated input_stream")
                 middle_code = input_stream[byte_buffer_index]
 
             if not skip_flag:
-                lookup_index = (first_code << current_shift) | (
-                    middle_code >> 8 - current_shift)
+                lookup_index = (first_code << current_shift) | (middle_code >> (8 - current_shift))
 
             if lookup_index == 256:
                 shift = 1
-                current_shift += 1
-                first_code = input_stream[byte_buffer_index]
-
+                current_shift = 1
+                previous_code = 0
+                skip_flag = False
+
                 self.tmp_decompression_buffer = [0] * self.MAX_CODES
                 self.tmp_buffer_index = 0
-
-                self.lzw_lookup_table = [LzwItem()] * self.MAX_CODES
+                self.lzw_lookup_table = [LzwItem() for _ in range(self.MAX_CODES)]
                 self.code_count = 257
+
+                first_code = input_stream[byte_buffer_index]
                 continue
 
             elif lookup_index == 257: # EOF marker
@@ -99,18 +113,18 @@ def decompress(self, input_stream: bytearray = bytearray()):
             # skipit part
             if previous_code == 0:
                 self.tmp_decompression_buffer[0] = lookup_index
+
             if lookup_index < self.code_count:
                 self.save_to_lookup_table(lookup_index)
                 if self.code_count < self.MAX_CODES:
-                    self.lzw_lookup_table[self.code_count] = LzwItem(
-                        previous_code,
-                        self.tmp_decompression_buffer[self.tmp_buffer_index])
+                    self.lzw_lookup_table[self.code_count] = \
+                        LzwItem(previous_code, self.tmp_decompression_buffer[self.tmp_buffer_index])
                     self.code_count += 1
             else:
-                self.lzw_lookup_table[self.code_count] = LzwItem(
-                    previous_code,
-                    self.tmp_decompression_buffer[self.tmp_buffer_index])
-                self.code_count += 1
+                if self.code_count < self.MAX_CODES:
+                    self.lzw_lookup_table[self.code_count] = \
+                        LzwItem(previous_code, self.tmp_decompression_buffer[self.tmp_buffer_index])
+                    self.code_count += 1
                 self.save_to_lookup_table(lookup_index)
             # end of skipit
 
@@ -120,4 +134,5 @@ def decompress(self, input_stream: bytearray = bytearray()):
         if self.code_count in [511, 1023, 2047, 4095]:
             shift += 1
             current_shift += 1
+
         previous_code = lookup_index
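
A quieter change in this file: both lookup-table initializations move from [LzwItem()] * self.MAX_CODES to a list comprehension. The multiplied form fills all 8192 slots with references to a single shared LzwItem, which is only safe while slots are replaced wholesale and never mutated in place; the comprehension gives each slot its own object and removes the aliasing hazard outright. A standalone illustration of the difference (not part of the commit):

    class LzwItem:
        def __init__(self, _prefix: int = 0, _suffix: int = 0) -> None:
            self.prefix = _prefix
            self.suffix = _suffix

    aliased = [LzwItem()] * 3                  # three references to ONE object
    aliased[0].prefix = 99
    print(aliased[1].prefix)                   # prints 99: every slot "changed"

    distinct = [LzwItem() for _ in range(3)]   # three independent objects
    distinct[0].prefix = 99
    print(distinct[1].prefix)                  # prints 0: other slots untouched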
