diff --git a/cassandra/query.py b/cassandra/query.py index 6c6878fdb4..9925c2a1ab 100644 --- a/cassandra/query.py +++ b/cassandra/query.py @@ -49,6 +49,9 @@ Only valid when using native protocol v4+ """ +_BIND_SENTINEL = object() +"""Sentinel for dict.get() in BoundStatement.bind() to distinguish missing keys from None values.""" + NON_ALPHA_REGEX = re.compile('[^a-zA-Z0-9]') START_BADCHAR_REGEX = re.compile('^[^a-zA-Z0-9]*') END_BADCHAR_REGEX = re.compile('[^a-zA-Z0-9_]*$') @@ -605,18 +608,28 @@ def bind(self, values): if isinstance(values, dict): values_dict = values values = [] + plain_dict = type(values_dict) is dict # sort values accordingly for col in col_meta: - try: - values.append(values_dict[col.name]) - except KeyError: - if proto_version >= 4: - values.append(UNSET_VALUE) - else: - raise KeyError( - 'Column name `%s` not found in bound dict.' % - (col.name)) + if plain_dict: + val = values_dict.get(col.name, _BIND_SENTINEL) + if val is not _BIND_SENTINEL: + values.append(val) + continue + else: + try: + values.append(values_dict[col.name]) + continue + except KeyError: + pass + + if proto_version >= 4: + values.append(UNSET_VALUE) + else: + raise KeyError( + 'Column name `%s` not found in bound dict.' % + (col.name)) value_len = len(values) col_meta_len = len(col_meta) diff --git a/tests/unit/test_parameter_binding.py b/tests/unit/test_parameter_binding.py index 5416ac461d..be5f98ace6 100644 --- a/tests/unit/test_parameter_binding.py +++ b/tests/unit/test_parameter_binding.py @@ -184,6 +184,14 @@ def test_unset_value(self): with pytest.raises(ValueError): self.bound.bind((0, 0, 0, UNSET_VALUE)) + def test_dict_subclass_missing_value(self): + class MissingDict(dict): + def __missing__(self, key): + return 0 + + self.bound.bind(MissingDict({'rk0': 0, 'rk1': 0, 'ck0': 0})) + assert self.bound.values == [b'\x00' * 4] * 4 + class BoundStatementTestV4(BoundStatementTestV3): protocol_version = 4 @@ -213,6 +221,14 @@ def test_unset_value(self): self.bound.bind((0, 0, 0, UNSET_VALUE)) assert self.bound.values[-1] == UNSET_VALUE + def test_dict_subclass_missing_value(self): + class MissingDict(dict): + def __missing__(self, key): + return 0 + + self.bound.bind(MissingDict({'rk0': 0, 'rk1': 0, 'ck0': 0})) + assert self.bound.values == [b'\x00' * 4] * 4 + class BoundStatementTestV5(BoundStatementTestV4): protocol_version = 5