Skip to content

Commit 0a077a5

Browse files
author
The TensorFlow Datasets Authors
committed
Handle the argument deserialize_method in mocked data sources.
PiperOrigin-RevId: 861170066
1 parent 7f40717 commit 0a077a5

2 files changed

Lines changed: 37 additions & 3 deletions

File tree

tensorflow_datasets/testing/mocking.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,13 @@
4444
import tree
4545

4646

47-
def _get_fake_data_components(decoders, features):
47+
def _get_fake_data_components(decoders, features, deserialize_method=None):
4848
"""Gets all the components to generate fake data in the tests.
4949
5050
Args:
5151
decoders: The decoders to override, or `None` if no decoding is used.
5252
features: The original features.
53+
deserialize_method: The deserialize method to use.
5354
5455
Returns:
5556
A tuple with the data generator class, the features, the feature specs and
@@ -76,6 +77,28 @@ def _get_fake_data_components(decoders, features):
7677
generator_cls = RandomFakeGenerator
7778
specs = features.get_tensor_info()
7879
decode_fn = lambda ex: ex # identity
80+
del deserialize_method
81+
# if deserialize_method == decode.DeserializeMethod.RAW_BYTES:
82+
# generator_cls = SerializedRandomFakeGenerator
83+
# specs = features.get_serialized_info()
84+
# decode_fn = lambda ex: ex # identity
85+
# else:
86+
# has_nested_dataset = any(
87+
# isinstance(f, features_lib.Dataset) for f in features._flatten(features) # pylint: disable=protected-access
88+
# )
89+
# if (
90+
# decoders is not None
91+
# or has_nested_dataset
92+
# or deserialize_method == decode.DeserializeMethod.DESERIALIZE_NO_DECODE
93+
# ):
94+
# # If a decoder is passed, encode/decode the examples.
95+
# generator_cls = EncodedRandomFakeGenerator
96+
# specs = features.get_serialized_info()
97+
# decode_fn = functools.partial(features.decode_example, decoders=decoders)
98+
# else:
99+
# generator_cls = RandomFakeGenerator
100+
# specs = features.get_tensor_info()
101+
# decode_fn = lambda ex: ex # identity
79102
return generator_cls, features, specs, decode_fn
80103

81104

@@ -352,7 +375,9 @@ def mock_as_dataset(self, split, decoders=None, read_config=None, **kwargs):
352375

353376
return ds
354377

355-
def mock_as_data_source(self, split, decoders=None, **kwargs):
378+
def mock_as_data_source(
379+
self, split, decoders=None, deserialize_method=None, **kwargs
380+
):
356381
"""Mocks `builder.as_data_source`."""
357382
del kwargs
358383
nonlocal mock_array_record_data_source
@@ -362,7 +387,7 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):
362387
split = {s: s for s in self.info.splits}
363388

364389
generator_cls, features, _, _ = _get_fake_data_components(
365-
decoders, self.info.features
390+
decoders, self.info.features, deserialize_method=deserialize_method
366391
)
367392
generator = generator_cls(features, num_examples)
368393

tensorflow_datasets/testing/mocking_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,15 @@ def test_mock_data_source():
372372
)
373373
assert isinstance(data_source[0]['image'], bytes)
374374

375+
# Without deserializing the examples
376+
# data_source = tfds.data_source(
377+
# 'imagenet2012',
378+
# split='train',
379+
# deserialize_method=tfds.decode.DeserializeMethod.RAW_BYTES,
380+
# )
381+
# assert len(data_source) == 10
382+
# assert isinstance(data_source[0], bytes)
383+
375384

376385
def test_mock_multiple_data_source():
377386
with tfds.testing.mock_data(num_examples=10):

0 commit comments

Comments
 (0)