4444import tree
4545
4646
47- def _get_fake_data_components (decoders , features ):
47+ def _get_fake_data_components (decoders , features , deserialize_method = None ):
4848 """Gets all the components to generate fake data in the tests.
4949
5050 Args:
5151 decoders: The decoders to override, or `None` if no decoding is used.
5252 features: The original features.
53+ deserialize_method: The deserialize method to use.
5354
5455 Returns:
5556 A tuple with the data generator class, the features, the feature specs and
@@ -76,6 +77,28 @@ def _get_fake_data_components(decoders, features):
7677 generator_cls = RandomFakeGenerator
7778 specs = features .get_tensor_info ()
7879 decode_fn = lambda ex : ex # identity
80+ del deserialize_method
81+ # if deserialize_method == decode.DeserializeMethod.RAW_BYTES:
82+ # generator_cls = SerializedRandomFakeGenerator
83+ # specs = features.get_serialized_info()
84+ # decode_fn = lambda ex: ex # identity
85+ # else:
86+ # has_nested_dataset = any(
87+ # isinstance(f, features_lib.Dataset) for f in features._flatten(features) # pylint: disable=protected-access
88+ # )
89+ # if (
90+ # decoders is not None
91+ # or has_nested_dataset
92+ # or deserialize_method == decode.DeserializeMethod.DESERIALIZE_NO_DECOD
93+ # ):
94+ # # If a decoder is passed, encode/decode the examples.
95+ # generator_cls = EncodedRandomFakeGenerator
96+ # specs = features.get_serialized_info()
97+ # decode_fn = functools.partial(features.decode_example, decoders=decoders
98+ # else:
99+ # generator_cls = RandomFakeGenerator
100+ # specs = features.get_tensor_info()
101+ # decode_fn = lambda ex: ex # identity
79102 return generator_cls , features , specs , decode_fn
80103
81104
@@ -352,7 +375,9 @@ def mock_as_dataset(self, split, decoders=None, read_config=None, **kwargs):
352375
353376 return ds
354377
355- def mock_as_data_source (self , split , decoders = None , ** kwargs ):
378+ def mock_as_data_source (
379+ self , split , decoders = None , deserialize_method = None , ** kwargs
380+ ):
356381 """Mocks `builder.as_data_source`."""
357382 del kwargs
358383 nonlocal mock_array_record_data_source
@@ -362,7 +387,7 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):
362387 split = {s : s for s in self .info .splits }
363388
364389 generator_cls , features , _ , _ = _get_fake_data_components (
365- decoders , self .info .features
390+ decoders , self .info .features , deserialize_method = deserialize_method
366391 )
367392 generator = generator_cls (features , num_examples )
368393
0 commit comments