Skip to content

Commit 8f0bf24

Browse files
Merge pull request #1 from opentensor/disable_pickle
Disable pickle
2 parents fd7032a + 122c589 commit 8f0bf24

2 files changed

Lines changed: 59 additions & 31 deletions

File tree

msgpack_numpy.py

Lines changed: 50 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def ndarray_to_bytes(obj):
5151
def tostr(x):
5252
return x
5353

54-
def encode(obj, chain=None):
54+
def encode(obj, chain=None, allow_pickle=False):
5555
"""
5656
Data encoder for serializing numpy data types.
5757
"""
@@ -60,6 +60,8 @@ def encode(obj, chain=None):
6060
# If the dtype is structured, store the interface description;
6161
# otherwise, store the corresponding array protocol type string:
6262
if obj.dtype.kind in ('V', 'O'):
63+
if obj.dtype.kind == 'O' and not allow_pickle:
64+
raise ValueError("Can't pickle object arrays if allow_pickle is False")
6365
kind = bytes(obj.dtype.kind, 'ascii')
6466
descr = obj.dtype.descr
6567
else:
@@ -81,7 +83,7 @@ def encode(obj, chain=None):
8183
else:
8284
return obj if chain is None else chain(obj)
8385

84-
def decode(obj, chain=None):
86+
def decode(obj, chain=None, allow_pickle=False):
8587
"""
8688
Decoder for deserializing numpy data types.
8789
"""
@@ -97,6 +99,8 @@ def decode(obj, chain=None):
9799
descr = [tuple(tostr(t) if type(t) is bytes else t for t in d) \
98100
for d in obj[b'type']]
99101
elif b'kind' in obj and obj[b'kind'] == b'O':
102+
if not allow_pickle:
103+
raise ValueError("Can't unpickle object arrays if allow_pickle is False")
100104
return pickle.loads(obj[b'data'])
101105
else:
102106
descr = obj[b'type']
@@ -138,8 +142,9 @@ def __init__(self, default=None,
138142
encoding='utf-8',
139143
unicode_errors='strict',
140144
use_single_float=False,
141-
autoreset=1):
142-
default = functools.partial(encode, chain=default)
145+
autoreset=1,
146+
allow_pickle=False):
147+
default = functools.partial(encode, chain=default, allow_pickle=allow_pickle)
143148
super(Packer, self).__init__(default=default,
144149
encoding=encoding,
145150
unicode_errors=unicode_errors,
@@ -149,8 +154,8 @@ class Unpacker(_Unpacker):
149154
def __init__(self, file_like=None, read_size=0, use_list=None,
150155
object_hook=None,
151156
object_pairs_hook=None, list_hook=None, encoding='utf-8',
152-
unicode_errors='strict', max_buffer_size=0):
153-
object_hook = functools.partial(decode, chain=object_hook)
157+
unicode_errors='strict', max_buffer_size=0, allow_pickle=False):
158+
object_hook = functools.partial(decode, chain=object_hook, allow_pickle=allow_pickle)
154159
super(Unpacker, self).__init__(file_like=file_like,
155160
read_size=read_size,
156161
use_list=use_list,
@@ -168,8 +173,9 @@ def __init__(self, default=None,
168173
use_single_float=False,
169174
autoreset=1,
170175
use_bin_type=True,
171-
strict_types=False):
172-
default = functools.partial(encode, chain=default)
176+
strict_types=False,
177+
allow_pickle=False):
178+
default = functools.partial(encode, chain=default, allow_pickle=allow_pickle)
173179
super(Packer, self).__init__(default=default,
174180
unicode_errors=unicode_errors,
175181
use_single_float=use_single_float,
@@ -183,8 +189,8 @@ def __init__(self, file_like=None, read_size=0, use_list=None,
183189
object_hook=None,
184190
object_pairs_hook=None, list_hook=None,
185191
unicode_errors='strict', max_buffer_size=0,
186-
ext_hook=msgpack.ExtType):
187-
object_hook = functools.partial(decode, chain=object_hook)
192+
ext_hook=msgpack.ExtType, allow_pickle=False):
193+
object_hook = functools.partial(decode, chain=object_hook, allow_pickle=allow_pickle)
188194
super(Unpacker, self).__init__(file_like=file_like,
189195
read_size=read_size,
190196
use_list=use_list,
@@ -205,8 +211,9 @@ def __init__(self,
205211
use_bin_type=True,
206212
strict_types=False,
207213
datetime=False,
208-
unicode_errors=None):
209-
default = functools.partial(encode, chain=default)
214+
unicode_errors=None,
215+
allow_pickle=False):
216+
default = functools.partial(encode, chain=default, allow_pickle=allow_pickle)
210217
super(Packer, self).__init__(default=default,
211218
use_single_float=use_single_float,
212219
autoreset=autoreset,
@@ -233,8 +240,9 @@ def __init__(self,
233240
max_bin_len=-1,
234241
max_array_len=-1,
235242
max_map_len=-1,
236-
max_ext_len=-1):
237-
object_hook = functools.partial(decode, chain=object_hook)
243+
max_ext_len=-1,
244+
allow_pickle=False):
245+
object_hook = functools.partial(decode, chain=object_hook, allow_pickle=allow_pickle)
238246
super(Unpacker, self).__init__(file_like=file_like,
239247
read_size=read_size,
240248
use_list=use_list,
@@ -268,41 +276,53 @@ def packb(o, **kwargs):
268276

269277
return Packer(**kwargs).pack(o)
270278

271-
def unpack(stream, **kwargs):
279+
def unpack(stream, allow_pickle=False, **kwargs):
272280
"""
273281
Unpack a packed object from a stream.
274282
"""
275283

276284
object_hook = kwargs.get('object_hook')
277-
kwargs['object_hook'] = functools.partial(decode, chain=object_hook)
285+
kwargs['object_hook'] = functools.partial(decode, chain=object_hook, allow_pickle=allow_pickle)
278286
return _unpack(stream, **kwargs)
279287

280-
def unpackb(packed, **kwargs):
288+
def unpackb(packed, allow_pickle=False, **kwargs):
281289
"""
282290
Unpack a packed object.
283291
"""
284292

285293
object_hook = kwargs.get('object_hook')
286-
kwargs['object_hook'] = functools.partial(decode, chain=object_hook)
294+
kwargs['object_hook'] = functools.partial(decode, chain=object_hook, allow_pickle=allow_pickle)
287295
return _unpackb(packed, **kwargs)
288296

289297
load = unpack
290298
loads = unpackb
291299
dump = pack
292300
dumps = packb
293301

294-
def patch():
302+
def patch(allow_pickle=False):
295303
"""
296304
Monkey patch msgpack module to enable support for serializing numpy types.
297305
"""
298-
299-
setattr(msgpack, 'Packer', Packer)
300-
setattr(msgpack, 'Unpacker', Unpacker)
301-
setattr(msgpack, 'load', unpack)
302-
setattr(msgpack, 'loads', unpackb)
303-
setattr(msgpack, 'dump', pack)
304-
setattr(msgpack, 'dumps', packb)
305-
setattr(msgpack, 'pack', pack)
306-
setattr(msgpack, 'packb', packb)
307-
setattr(msgpack, 'unpack', unpack)
308-
setattr(msgpack, 'unpackb', unpackb)
306+
class Packer_(Packer):
307+
def __init__(self, *args, **kws):
308+
super(Packer, self).__init__(*args, **kws, allow_pickle=allow_pickle)
309+
310+
class Unpacker_(Unpacker):
311+
def __init__(self, *args, **kws):
312+
super(Unpacker, self).__init__(*args, **kws, allow_pickle=allow_pickle)
313+
314+
pack_ = functools.partial(pack, allow_pickle=allow_pickle)
315+
packb_ = functools.partial(packb, allow_pickle=allow_pickle)
316+
unpack_ = functools.partial(unpack, allow_pickle=allow_pickle)
317+
unpackb_ = functools.partial(unpackb, allow_pickle=allow_pickle)
318+
319+
setattr(msgpack, 'Packer', Packer_)
320+
setattr(msgpack, 'Unpacker', Unpacker_)
321+
setattr(msgpack, 'load', unpack_)
322+
setattr(msgpack, 'loads', unpackb_)
323+
setattr(msgpack, 'dump', pack_)
324+
setattr(msgpack, 'dumps', packb_)
325+
setattr(msgpack, 'pack', pack_)
326+
setattr(msgpack, 'packb', packb_)
327+
setattr(msgpack, 'unpack', unpack_)
328+
setattr(msgpack, 'unpackb', unpackb_)

tests.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __eq__(self, other):
2424

2525
class test_numpy_msgpack(TestCase):
2626
def setUp(self):
27-
patch()
27+
patch(allow_pickle=True)
2828

2929
def encode_decode(self, x, use_list=True, max_bin_len=-1):
3030
x_enc = msgpack.packb(x)
@@ -288,5 +288,13 @@ def test_numpy_nested_structured_array(self):
288288
assert_array_equal(x, x_rec)
289289
self.assertEqual(x.dtype, x_rec.dtype)
290290

291+
class test_numpy_msgpack_no_pickle(test_numpy_msgpack):
292+
def setUp(self):
293+
patch(allow_pickle=False)
294+
295+
def test_numpy_array_object(self):
296+
x = np.random.rand(5).astype(object)
297+
self.assertRaises(ValueError, self.encode_decode, x)
298+
291299
if __name__ == '__main__':
292300
main()

0 commit comments

Comments
 (0)