Skip to content

Commit f2a6069

Browse files
Fix decimal handling and timestamp tests for Python SDK
1 parent 4f42f9a commit f2a6069

3 files changed

Lines changed: 375 additions & 41 deletions

File tree

paimon-python/pypaimon/table/row/generic_row.py

Lines changed: 101 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -16,38 +16,41 @@
1616
# limitations under the License.
1717
################################################################################
1818

19+
import calendar
20+
import decimal
1921
import struct
22+
from dataclasses import dataclass
2023
from datetime import date, datetime, time, timedelta
2124
from decimal import Decimal
2225
from typing import Any, List, Union
2326

24-
from dataclasses import dataclass
25-
2627
from pypaimon.schema.data_types import AtomicType, DataField, DataType
2728
from pypaimon.table.row.binary_row import BinaryRow
28-
from pypaimon.table.row.internal_row import InternalRow, RowKind
2929
from pypaimon.table.row.blob import BlobData
30+
from pypaimon.table.row.internal_row import InternalRow, RowKind
3031

32+
_DECIMAL_CTX = decimal.Context(prec=100, rounding=decimal.ROUND_HALF_UP)
3133

32-
def _decimal_to_unscaled(d: Decimal, scale: int) -> int:
33-
"""Convert a Decimal to its unscaled integer value without precision loss.
34-
Raises ArithmeticError if the value has more fractional digits than scale."""
35-
sign, digits, exponent = d.as_tuple()
34+
35+
def _decimal_to_unscaled_with_check(d: Decimal, precision: int, scale: int):
36+
"""Round decimal with HALF_UP, check precision overflow, and return unscaled value.
37+
Returns (unscaled_int, True) on overflow, (unscaled_int, False) on success."""
38+
rounded = d.quantize(Decimal(10) ** -scale, context=_DECIMAL_CTX)
39+
sign, digits, exponent = rounded.as_tuple()
40+
# Precision overflow check
41+
if rounded != 0 and len(digits) > precision:
42+
return 0, True
3643
int_digits = int(''.join(str(x) for x in digits)) if digits != (0,) else 0
3744
shift = exponent + scale
3845
if shift >= 0:
3946
unscaled = int_digits * (10 ** shift)
4047
else:
41-
divisor = 10 ** (-shift)
42-
if int_digits % divisor != 0:
43-
raise ArithmeticError(
44-
f"Decimal {d} has more fractional digits than scale {scale}")
45-
unscaled = int_digits // divisor
46-
return -unscaled if sign else unscaled
48+
unscaled = int_digits // (10 ** (-shift))
49+
return (-unscaled if sign else unscaled), False
4750

4851

4952
def _parse_type_precision_scale(data_type):
50-
"""Parse precision and scale from type string like DECIMAL(38, 10) or TIMESTAMP(6)."""
53+
"""Parse precision and scale from type string like DECIMAL(38, 10)."""
5154
type_str = str(data_type)
5255
if '(' in type_str and ')' in type_str:
5356
try:
@@ -61,6 +64,28 @@ def _parse_type_precision_scale(data_type):
6164
return 0, 0
6265

6366

67+
_EPOCH = datetime(1970, 1, 1)
68+
69+
70+
def _datetime_to_millis_and_nanos(value: datetime):
71+
"""Convert datetime to (epoch_millis, nano_of_millisecond) without float arithmetic."""
72+
epoch_seconds = calendar.timegm(value.timetuple())
73+
millis = epoch_seconds * 1000 + value.microsecond // 1000
74+
nano_of_millisecond = (value.microsecond % 1000) * 1000
75+
return millis, nano_of_millisecond
76+
77+
78+
def _millis_nanos_to_datetime(millis: int, nano_of_millisecond: int = 0) -> datetime:
79+
"""Convert (epoch_millis, nano_of_millisecond) to datetime. Nanos truncated to micros."""
80+
total_micros = millis * 1000 + nano_of_millisecond // 1000
81+
seconds = total_micros // 1_000_000
82+
micros = total_micros % 1_000_000
83+
if micros < 0:
84+
seconds -= 1
85+
micros += 1_000_000
86+
return _EPOCH + timedelta(seconds=seconds, microseconds=micros)
87+
88+
6489
@dataclass
6590
class GenericRow(InternalRow):
6691

@@ -271,26 +296,43 @@ def _unscaled_to_decimal(cls, unscaled_value: int, scale: int) -> Decimal:
271296
return Decimal((sign, digits, -scale))
272297

273298
@classmethod
def _parse_decimal(cls, bytes_data: bytes, base_offset: int, field_offset: int, data_type: DataType):
    """Decode a DECIMAL field from serialized row bytes.

    Precision and scale come from ``data_type``. With precision <= 18 the
    8-byte fixed slot holds the unscaled long directly. Otherwise the slot
    holds ``(cursor << 32) | byte_length`` and the big-endian signed unscaled
    bytes are read at ``base_offset + cursor``.

    Returns the decoded Decimal, or None when a non-compact value has more
    digits than the declared precision.

    Raises:
        ValueError: If the parsed precision is not positive.
    """
    precision, scale = _parse_type_precision_scale(data_type)
    if precision <= 0:
        raise ValueError(f"Decimal requires precision > 0, got {precision}")

    # Both layouts start from the same little-endian 8-byte fixed slot.
    fixed_slot = struct.unpack('<q', bytes_data[field_offset:field_offset + 8])[0]
    if precision <= 18:
        return cls._unscaled_to_decimal(fixed_slot, scale)

    start = base_offset + ((fixed_slot >> 32) & 0xFFFFFFFF)
    length = fixed_slot & 0xFFFFFFFF
    unscaled = int.from_bytes(bytes_data[start:start + length], byteorder='big', signed=True)
    decoded = cls._unscaled_to_decimal(unscaled, scale)
    # Precision overflow is reported as null rather than as an error.
    if decoded != 0 and len(decoded.as_tuple().digits) > precision:
        return None
    return decoded
289321

290322
@classmethod
def _parse_timestamp(cls, bytes_data: bytes, base_offset: int, field_offset: int, data_type: DataType) -> datetime:
    """Decode a TIMESTAMP field from serialized row bytes.

    With precision <= 3 the 8-byte fixed slot holds epoch milliseconds.
    Otherwise the slot holds ``(cursor << 32) | nano_of_millisecond`` and the
    epoch milliseconds are read as a little-endian long at
    ``base_offset + cursor``.
    """
    precision, _scale = _parse_type_precision_scale(data_type)
    fixed_slot = struct.unpack('<q', bytes_data[field_offset:field_offset + 8])[0]
    if precision <= 3:
        return _millis_nanos_to_datetime(fixed_slot)

    nanos = fixed_slot & 0xFFFFFFFF
    millis_start = base_offset + ((fixed_slot >> 32) & 0xFFFFFFFF)
    epoch_millis = struct.unpack('<q', bytes_data[millis_start:millis_start + 8])[0]
    return _millis_nanos_to_datetime(epoch_millis, nanos)
294336

295337
@classmethod
296338
def _parse_date(cls, bytes_data: bytes, field_offset: int) -> date:
@@ -339,17 +381,39 @@ def to_bytes(cls, row: Union[GenericRow, BinaryRow]) -> bytes:
339381
raise ValueError(f"BinaryRow only support AtomicType yet, meet {field.type.__class__}")
340382

341383
type_name = field.type.type.upper()
342-
is_var_len_type = any(type_name.startswith(p) for p in ['CHAR', 'VARCHAR', 'STRING',
343-
'BINARY', 'VARBINARY', 'BYTES', 'BLOB'])
384+
is_var_len_type = any(type_name.startswith(p) for p in [
385+
'CHAR', 'VARCHAR', 'STRING', 'BINARY', 'VARBINARY', 'BYTES', 'BLOB'])
344386
is_decimal_type = type_name.startswith('DECIMAL') or type_name.startswith('NUMERIC')
387+
is_timestamp_type = type_name.startswith('TIMESTAMP')
345388
decimal_precision, decimal_scale = _parse_type_precision_scale(field.type) if is_decimal_type else (0, 0)
346389
is_high_precision_decimal = is_decimal_type and decimal_precision > 18
347-
348-
if is_var_len_type or is_high_precision_decimal:
390+
timestamp_precision = _parse_type_precision_scale(field.type)[0] if is_timestamp_type else 0
391+
is_non_compact_timestamp = is_timestamp_type and timestamp_precision > 3
392+
393+
# Precision overflow -> null
394+
if is_decimal_type and value is not None:
395+
d = value if isinstance(value, Decimal) else Decimal(str(value))
396+
unscaled_value, overflow = _decimal_to_unscaled_with_check(d, decimal_precision, decimal_scale)
397+
if overflow:
398+
cls._set_null_bit(fixed_part, 0, i)
399+
struct.pack_into('<q', fixed_part, field_fixed_offset, 0)
400+
continue
401+
402+
if is_non_compact_timestamp:
403+
# Non-compact: millis in var area, (offset << 32 | nanoOfMilli) in fixed part
404+
if value.tzinfo is not None:
405+
raise RuntimeError("datetime tzinfo not supported yet")
406+
ts_millis, nano_of_millisecond = _datetime_to_millis_and_nanos(value)
407+
var_value_bytes = struct.pack('<q', ts_millis)
408+
offset_in_variable_part = current_variable_offset
409+
variable_part_data.append(var_value_bytes)
410+
current_variable_offset += 8
411+
absolute_offset = fixed_part_size + offset_in_variable_part
412+
offset_and_nano = (absolute_offset << 32) | nano_of_millisecond
413+
struct.pack_into('<q', fixed_part, field_fixed_offset, offset_and_nano)
414+
elif is_var_len_type or is_high_precision_decimal:
349415
if is_high_precision_decimal:
350-
d = value if isinstance(value, Decimal) else Decimal(str(value))
351-
unscaled_value = _decimal_to_unscaled(d, decimal_scale)
352-
# Convert to big-endian signed bytes (minimal representation)
416+
# Big-endian signed bytes
353417
if unscaled_value == 0:
354418
value_bytes = b'\x00'
355419
else:
@@ -370,7 +434,7 @@ def to_bytes(cls, row: Union[GenericRow, BinaryRow]) -> bytes:
370434
header_byte = 0x80 | length
371435
fixed_part[field_fixed_offset + 7] = header_byte
372436
else:
373-
# Non-compact decimal uses fixed 16 bytes, others use 8-byte alignment
437+
# Non-compact decimal: fixed 16 bytes; others: 8-byte aligned
374438
if is_high_precision_decimal:
375439
var_length = 16
376440
else:
@@ -428,6 +492,11 @@ def _serialize_field_value(cls, value: Any, data_type: AtomicType) -> bytes:
428492
f"via the variable-length path in to_bytes(), not _serialize_field_value()")
429493
return cls._serialize_decimal(value, data_type)
430494
elif type_name.startswith('TIMESTAMP'):
495+
precision = _parse_type_precision_scale(data_type)[0]
496+
if precision > 3:
497+
raise ValueError(
498+
f"Non-compact timestamp (precision={precision}) must be serialized "
499+
f"via the variable-length path in to_bytes(), not _serialize_field_value()")
431500
return cls._serialize_timestamp(value)
432501
elif type_name in ['DATE']:
433502
return cls._serialize_date(value) + b'\x00' * 4
@@ -466,17 +535,17 @@ def _serialize_double(cls, value: float) -> bytes:
466535

467536
@classmethod
def _serialize_decimal(cls, value: Decimal, data_type: DataType) -> bytes:
    """Pack a compact decimal as its unscaled value in a little-endian 8-byte long.

    Non-Decimal inputs are converted through ``Decimal(str(value))`` first.
    """
    precision, scale = _parse_type_precision_scale(data_type)
    if not isinstance(value, Decimal):
        value = Decimal(str(value))
    # NOTE(review): the overflow flag is discarded here; to_bytes() checks it
    # before reaching this path - confirm no other caller relies on it.
    unscaled, _overflow = _decimal_to_unscaled_with_check(value, precision, scale)
    return struct.pack('<q', unscaled)
474543

475544
@classmethod
def _serialize_timestamp(cls, value: datetime) -> bytes:
    """Pack a naive datetime as epoch milliseconds in a little-endian 8-byte long.

    Raises:
        RuntimeError: If ``value`` carries a tzinfo.
    """
    if value.tzinfo is not None:
        raise RuntimeError("datetime tzinfo not supported yet")
    # The sub-millisecond part is dropped in this compact layout.
    epoch_millis, _nanos = _datetime_to_millis_and_nanos(value)
    return struct.pack('<q', epoch_millis)
481550

482551
@classmethod

paimon-python/pypaimon/tests/decimal_test.py

Lines changed: 81 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,8 @@
2020
from decimal import Decimal
2121

2222
from pypaimon.schema.data_types import AtomicType, DataField
23-
from pypaimon.table.row.generic_row import GenericRow, GenericRowSerializer, GenericRowDeserializer
23+
from pypaimon.table.row.generic_row import (GenericRow, GenericRowDeserializer,
24+
GenericRowSerializer)
2425
from pypaimon.table.row.row_kind import RowKind
2526

2627

@@ -113,7 +114,6 @@ def test_decimal_mixed_with_other_types(self):
113114
self.assertEqual(result.values[3], Decimal("12312455.22"))
114115
self.assertAlmostEqual(result.values[4], 3.14)
115116

116-
117117
def test_decimal_compact_binary_format(self):
118118
"""Verify compact decimal binary layout: unscaled long in fixed part."""
119119
fields = [DataField(0, "d", AtomicType("DECIMAL(4, 2)"))]
@@ -148,7 +148,7 @@ def test_decimal_not_compact_binary_format(self):
148148

149149
# cursor should point to the variable area (== fixed_part_size)
150150
self.assertEqual(cursor, fixed_part_size)
151-
# variable area should be exactly 16 bytes (matching Java's cursor += 16)
151+
# variable area should be exactly 16 bytes
152152
var_area = data[cursor:]
153153
self.assertEqual(len(var_area), 16)
154154
# unscaled bytes are big-endian signed
@@ -157,7 +157,6 @@ def test_decimal_not_compact_binary_format(self):
157157
# Decimal("5.55000") with scale=5 => unscaled = 555000
158158
self.assertEqual(unscaled_value, 555000)
159159

160-
161160
def test_decimal_boundary_precision(self):
162161
"""Test boundary: DECIMAL(18, ...) is compact, DECIMAL(19, ...) is non-compact."""
163162
# precision=18: last compact
@@ -197,12 +196,85 @@ def test_decimal_zero_different_scales(self):
197196
result = GenericRowDeserializer.from_bytes(serialized, fields)
198197
self.assertEqual(result.values[0], val)
199198

200-
def test_decimal_truncation_raises(self):
201-
"""Serializing a value with more fractional digits than scale should raise."""
199+
def test_decimal_half_up_rounding(self):
    """Values with more fractional digits than the scale are rounded HALF_UP on write."""
    fields = [DataField(0, "d", AtomicType("DECIMAL(10, 2)"))]

    # input literal -> expected value after a serialize/deserialize round trip
    expectations = {
        "1.999": "2.00",    # carries into the integer part
        "1.235": "1.24",    # exact half rounds up
        "1.234": "1.23",    # below half rounds down
        "1.225": "1.23",    # exact half rounds up
        "-1.235": "-1.24",  # negative half rounds away from zero
    }
    for raw, rounded in expectations.items():
        val = Decimal(raw)
        with self.subTest(value=val):
            payload = GenericRowSerializer.to_bytes(GenericRow([val], fields, RowKind.INSERT))
            decoded = GenericRowDeserializer.from_bytes(payload, fields)
            self.assertEqual(decoded.values[0], Decimal(rounded))
216+
217+
def test_decimal_precision_overflow_returns_null(self):
    """A value that does not fit the declared precision round-trips as None."""
    # DECIMAL(4, 2): two integer digits plus two fractional digits, max 99.99
    fields = [DataField(0, "d", AtomicType("DECIMAL(4, 2)"))]

    def round_trip(raw):
        row = GenericRow([Decimal(raw)], fields, RowKind.INSERT)
        payload = GenericRowSerializer.to_bytes(row)
        return GenericRowDeserializer.from_bytes(payload, fields).values[0]

    # five digits in total: exceeds precision 4
    self.assertIsNone(round_trip("999.99"))
    # rounds to 100.00, which is five digits as well
    self.assertIsNone(round_trip("99.999"))
    # largest value that still fits
    self.assertEqual(round_trip("99.99"), Decimal("99.99"))
239+
240+
def test_decimal_precision_overflow_high_precision(self):
    """The overflow-to-None rule also applies to non-compact decimals."""
    # DECIMAL(20, 5): fifteen integer digits plus five fractional digits
    fields = [DataField(0, "d", AtomicType("DECIMAL(20, 5)"))]

    def round_trip(raw):
        row = GenericRow([Decimal(raw)], fields, RowKind.INSERT)
        payload = GenericRowSerializer.to_bytes(row)
        return GenericRowDeserializer.from_bytes(payload, fields).values[0]

    # 15 + 5 digits: fits exactly
    self.assertEqual(round_trip("123456789012345.12345"), Decimal("123456789012345.12345"))
    # 16 + 5 = 21 digits: one more than the declared precision
    self.assertIsNone(round_trip("1234567890123456.12345"))
256+
257+
def test_decimal_deserialization_precision_overflow_non_compact(self):
    """Reading a non-compact decimal with a narrower precision than it was written with yields None."""
    wide = [DataField(0, "d", AtomicType("DECIMAL(38, 5)"))]
    narrow = [DataField(0, "d", AtomicType("DECIMAL(20, 5)"))]

    # 21 digits in total: fits DECIMAL(38, 5) but not DECIMAL(20, 5)
    source_row = GenericRow([Decimal("1234567890123456.12345")], wide, RowKind.INSERT)
    payload = GenericRowSerializer.to_bytes(source_row)

    decoded = GenericRowDeserializer.from_bytes(payload, narrow)
    self.assertIsNone(decoded.values[0])
268+
269+
def test_decimal_deserialization_invalid_precision(self):
    """Deserializing against a decimal type with precision <= 0 raises ValueError."""
    valid = [DataField(0, "d", AtomicType("DECIMAL(10, 2)"))]
    payload = GenericRowSerializer.to_bytes(GenericRow([Decimal("1.23")], valid, RowKind.INSERT))

    # Same bytes, but the reader is given a schema with an impossible precision.
    bad = [DataField(0, "d", AtomicType("DECIMAL(0, 2)"))]
    self.assertRaises(ValueError, GenericRowDeserializer.from_bytes, payload, bad)
206278

207279

208280
if __name__ == '__main__':

0 commit comments

Comments
 (0)