Skip to content

Commit 68dbf6c

Browse files
committed
fix: preserve dict subclass semantics in bind fast path
1 parent 8f2bc6a commit 68dbf6c

2 files changed

Lines changed: 30 additions & 4 deletions

File tree

cassandra/query.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -608,13 +608,23 @@ def bind(self, values):
608608
if isinstance(values, dict):
609609
values_dict = values
610610
values = []
611+
plain_dict = type(values_dict) is dict
611612

612613
# sort values accordingly
613614
for col in col_meta:
614-
val = values_dict.get(col.name, _BIND_SENTINEL)
615-
if val is not _BIND_SENTINEL:
616-
values.append(val)
617-
elif proto_version >= 4:
615+
if plain_dict:
616+
val = values_dict.get(col.name, _BIND_SENTINEL)
617+
if val is not _BIND_SENTINEL:
618+
values.append(val)
619+
continue
620+
else:
621+
try:
622+
values.append(values_dict[col.name])
623+
continue
624+
except KeyError:
625+
pass
626+
627+
if proto_version >= 4:
618628
values.append(UNSET_VALUE)
619629
else:
620630
raise KeyError(

tests/unit/test_parameter_binding.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,14 @@ def test_unset_value(self):
184184
with pytest.raises(ValueError):
185185
self.bound.bind((0, 0, 0, UNSET_VALUE))
186186

187+
def test_dict_subclass_missing_value(self):
188+
class MissingDict(dict):
189+
def __missing__(self, key):
190+
return 0
191+
192+
self.bound.bind(MissingDict({'rk0': 0, 'rk1': 0, 'ck0': 0}))
193+
assert self.bound.values == [b'\x00' * 4] * 4
194+
187195

188196
class BoundStatementTestV4(BoundStatementTestV3):
189197
protocol_version = 4
@@ -213,6 +221,14 @@ def test_unset_value(self):
213221
self.bound.bind((0, 0, 0, UNSET_VALUE))
214222
assert self.bound.values[-1] == UNSET_VALUE
215223

224+
def test_dict_subclass_missing_value(self):
225+
class MissingDict(dict):
226+
def __missing__(self, key):
227+
return 0
228+
229+
self.bound.bind(MissingDict({'rk0': 0, 'rk1': 0, 'ck0': 0}))
230+
assert self.bound.values == [b'\x00' * 4] * 4
231+
216232

217233
class BoundStatementTestV5(BoundStatementTestV4):
218234
protocol_version = 5

0 commit comments

Comments
 (0)