diff --git a/google/cloud/spanner_v1/data_types.py b/google/cloud/spanner_v1/data_types.py index 6703f359e9..07e4faa1c1 100644 --- a/google/cloud/spanner_v1/data_types.py +++ b/google/cloud/spanner_v1/data_types.py @@ -56,6 +56,56 @@ def __init__(self, *args, **kwargs): if not self._is_null: super(JsonObject, self).__init__(*args, **kwargs) + def __len__(self): + if self._is_null: + return 0 + if self._is_array: + return len(self._array_value) + if self._is_scalar_value: + return 1 + return super(JsonObject, self).__len__() + + def __bool__(self): + if self._is_null: + return False + if self._is_array: + return bool(self._array_value) + if self._is_scalar_value: + return True + return len(self) > 0 + + def __iter__(self): + if self._is_array: + return iter(self._array_value) + if self._is_scalar_value: + raise TypeError(f"'{type(self._simple_value).__name__}' object is not iterable") + return super(JsonObject, self).__iter__() + + def __getitem__(self, key): + if self._is_array: + return self._array_value[key] + if self._is_scalar_value: + raise TypeError(f"'{type(self._simple_value).__name__}' object is not subscriptable") + return super(JsonObject, self).__getitem__(key) + + def __contains__(self, item): + if self._is_array: + return item in self._array_value + if self._is_scalar_value: + raise TypeError(f"argument of type '{type(self._simple_value).__name__}' is not iterable") + return super(JsonObject, self).__contains__(item) + + def __eq__(self, other): + if isinstance(other, JsonObject): + return self.serialize() == other.serialize() + if self._is_array: + return self._array_value == other + if self._is_scalar_value: + return self._simple_value == other + if self._is_null: + return other is None or (isinstance(other, dict) and len(other) == 0) + return super(JsonObject, self).__eq__(other) + def __repr__(self): if self._is_array: return str(self._array_value) diff --git a/tests/unit/test_datatypes.py b/tests/unit/test_datatypes.py index 65ccacb4ff..25cc834808 100644 --- a/tests/unit/test_datatypes.py +++ b/tests/unit/test_datatypes.py @@ -96,3 +96,90 @@ def test_w_JsonObject_of_list_of_simple_JsonData(self): expected = json.dumps(data, sort_keys=True, separators=(",", ":")) data_jsonobject = JsonObject(JsonObject(data)) self.assertEqual(data_jsonobject.serialize(), expected) + + +class Test_JsonObject_dict_protocol(unittest.TestCase): + """Verify that JsonObject behaves correctly with standard Python + operations (len, bool, iteration, indexing) for all JSON variants.""" + + def test_array_len(self): + obj = JsonObject([{"id": 1}, {"id": 2}]) + self.assertEqual(len(obj), 2) + + def test_array_bool_truthy(self): + obj = JsonObject([{"id": 1}]) + self.assertTrue(obj) + + def test_array_bool_empty(self): + obj = JsonObject([]) + self.assertFalse(obj) + + def test_array_iter(self): + data = [{"a": 1}, {"b": 2}] + obj = JsonObject(data) + self.assertEqual(list(obj), data) + + def test_array_getitem(self): + data = [{"a": 1}, {"b": 2}] + obj = JsonObject(data) + self.assertEqual(obj[0], {"a": 1}) + self.assertEqual(obj[1], {"b": 2}) + + def test_array_contains(self): + data = [1, 2, 3] + obj = JsonObject(data) + self.assertIn(2, obj) + self.assertNotIn(4, obj) + + def test_array_eq(self): + data = [{"id": 1}] + obj = JsonObject(data) + self.assertEqual(obj, data) + + def test_array_json_dumps(self): + data = [{"id": "m1", "content": "hello"}] + obj = JsonObject(data) + result = json.loads(json.dumps(list(obj))) + self.assertEqual(result, data) + + def test_dict_len(self): + obj = JsonObject({"a": 1, "b": 2}) + self.assertEqual(len(obj), 2) + + def test_dict_bool(self): + obj = JsonObject({"a": 1}) + self.assertTrue(obj) + + def test_dict_iter(self): + obj = JsonObject({"a": 1, "b": 2}) + self.assertEqual(sorted(obj), ["a", "b"]) + + def test_dict_getitem(self): + obj = JsonObject({"key": "value"}) + self.assertEqual(obj["key"], "value") + + def test_null_len(self): + obj = JsonObject(None) + self.assertEqual(len(obj), 0) + + def test_null_bool(self): + obj = JsonObject(None) + self.assertFalse(obj) + + def test_scalar_len(self): + obj = JsonObject(42) + self.assertEqual(len(obj), 1) + + def test_scalar_bool(self): + obj = JsonObject(42) + self.assertTrue(obj) + + def test_scalar_not_iterable(self): + obj = JsonObject(42) + with self.assertRaises(TypeError): + iter(obj) + + def test_scalar_not_subscriptable(self): + obj = JsonObject(42) + with self.assertRaises(TypeError): + obj[0]