@@ -51,7 +51,7 @@ def ndarray_to_bytes(obj):
5151 def tostr (x ):
5252 return x
5353
54- def encode (obj , chain = None ):
54+ def encode (obj , chain = None , allow_pickle = False ):
5555 """
5656 Data encoder for serializing numpy data types.
5757 """
@@ -60,6 +60,8 @@ def encode(obj, chain=None):
6060 # If the dtype is structured, store the interface description;
6161 # otherwise, store the corresponding array protocol type string:
6262 if obj .dtype .kind in ('V' , 'O' ):
63+ if obj .dtype .kind == 'O' and not allow_pickle :
64+ raise ValueError ("Can't pickle object arrays if allow_pickle is False" )
6365 kind = bytes (obj .dtype .kind , 'ascii' )
6466 descr = obj .dtype .descr
6567 else :
@@ -81,7 +83,7 @@ def encode(obj, chain=None):
8183 else :
8284 return obj if chain is None else chain (obj )
8385
84- def decode (obj , chain = None ):
86+ def decode (obj , chain = None , allow_pickle = False ):
8587 """
8688 Decoder for deserializing numpy data types.
8789 """
@@ -97,6 +99,8 @@ def decode(obj, chain=None):
9799 descr = [tuple (tostr (t ) if type (t ) is bytes else t for t in d ) \
98100 for d in obj [b'type' ]]
99101 elif b'kind' in obj and obj [b'kind' ] == b'O' :
102+ if not allow_pickle :
103+ raise ValueError ("Can't unpickle object arrays if allow_pickle is False" )
100104 return pickle .loads (obj [b'data' ])
101105 else :
102106 descr = obj [b'type' ]
@@ -138,8 +142,9 @@ def __init__(self, default=None,
138142 encoding = 'utf-8' ,
139143 unicode_errors = 'strict' ,
140144 use_single_float = False ,
141- autoreset = 1 ):
142- default = functools .partial (encode , chain = default )
145+ autoreset = 1 ,
146+ allow_pickle = False ):
147+ default = functools .partial (encode , chain = default , allow_pickle = allow_pickle )
143148 super (Packer , self ).__init__ (default = default ,
144149 encoding = encoding ,
145150 unicode_errors = unicode_errors ,
@@ -149,8 +154,8 @@ class Unpacker(_Unpacker):
149154 def __init__ (self , file_like = None , read_size = 0 , use_list = None ,
150155 object_hook = None ,
151156 object_pairs_hook = None , list_hook = None , encoding = 'utf-8' ,
152- unicode_errors = 'strict' , max_buffer_size = 0 ):
153- object_hook = functools .partial (decode , chain = object_hook )
157+ unicode_errors = 'strict' , max_buffer_size = 0 , allow_pickle = False ):
158+ object_hook = functools .partial (decode , chain = object_hook , allow_pickle = allow_pickle )
154159 super (Unpacker , self ).__init__ (file_like = file_like ,
155160 read_size = read_size ,
156161 use_list = use_list ,
@@ -168,8 +173,9 @@ def __init__(self, default=None,
168173 use_single_float = False ,
169174 autoreset = 1 ,
170175 use_bin_type = True ,
171- strict_types = False ):
172- default = functools .partial (encode , chain = default )
176+ strict_types = False ,
177+ allow_pickle = False ):
178+ default = functools .partial (encode , chain = default , allow_pickle = allow_pickle )
173179 super (Packer , self ).__init__ (default = default ,
174180 unicode_errors = unicode_errors ,
175181 use_single_float = use_single_float ,
@@ -183,8 +189,8 @@ def __init__(self, file_like=None, read_size=0, use_list=None,
183189 object_hook = None ,
184190 object_pairs_hook = None , list_hook = None ,
185191 unicode_errors = 'strict' , max_buffer_size = 0 ,
186- ext_hook = msgpack .ExtType ):
187- object_hook = functools .partial (decode , chain = object_hook )
192+ ext_hook = msgpack .ExtType , allow_pickle = False ):
193+ object_hook = functools .partial (decode , chain = object_hook , allow_pickle = allow_pickle )
188194 super (Unpacker , self ).__init__ (file_like = file_like ,
189195 read_size = read_size ,
190196 use_list = use_list ,
@@ -205,8 +211,9 @@ def __init__(self,
205211 use_bin_type = True ,
206212 strict_types = False ,
207213 datetime = False ,
208- unicode_errors = None ):
209- default = functools .partial (encode , chain = default )
214+ unicode_errors = None ,
215+ allow_pickle = False ):
216+ default = functools .partial (encode , chain = default , allow_pickle = allow_pickle )
210217 super (Packer , self ).__init__ (default = default ,
211218 use_single_float = use_single_float ,
212219 autoreset = autoreset ,
@@ -233,8 +240,9 @@ def __init__(self,
233240 max_bin_len = - 1 ,
234241 max_array_len = - 1 ,
235242 max_map_len = - 1 ,
236- max_ext_len = - 1 ):
237- object_hook = functools .partial (decode , chain = object_hook )
243+ max_ext_len = - 1 ,
244+ allow_pickle = False ):
245+ object_hook = functools .partial (decode , chain = object_hook , allow_pickle = allow_pickle )
238246 super (Unpacker , self ).__init__ (file_like = file_like ,
239247 read_size = read_size ,
240248 use_list = use_list ,
@@ -268,41 +276,53 @@ def packb(o, **kwargs):
268276
269277 return Packer (** kwargs ).pack (o )
270278
271- def unpack (stream , ** kwargs ):
279+ def unpack (stream , allow_pickle = False , ** kwargs ):
272280 """
273281 Unpack a packed object from a stream.
274282 """
275283
276284 object_hook = kwargs .get ('object_hook' )
277- kwargs ['object_hook' ] = functools .partial (decode , chain = object_hook )
285+ kwargs ['object_hook' ] = functools .partial (decode , chain = object_hook , allow_pickle = allow_pickle )
278286 return _unpack (stream , ** kwargs )
279287
280- def unpackb (packed , ** kwargs ):
288+ def unpackb (packed , allow_pickle = False , ** kwargs ):
281289 """
282290 Unpack a packed object.
283291 """
284292
285293 object_hook = kwargs .get ('object_hook' )
286- kwargs ['object_hook' ] = functools .partial (decode , chain = object_hook )
294+ kwargs ['object_hook' ] = functools .partial (decode , chain = object_hook , allow_pickle = allow_pickle )
287295 return _unpackb (packed , ** kwargs )
288296
289297load = unpack
290298loads = unpackb
291299dump = pack
292300dumps = packb
293301
294- def patch ():
302+ def patch (allow_pickle = False ):
295303 """
296304 Monkey patch msgpack module to enable support for serializing numpy types.
297305 """
298-
299- setattr (msgpack , 'Packer' , Packer )
300- setattr (msgpack , 'Unpacker' , Unpacker )
301- setattr (msgpack , 'load' , unpack )
302- setattr (msgpack , 'loads' , unpackb )
303- setattr (msgpack , 'dump' , pack )
304- setattr (msgpack , 'dumps' , packb )
305- setattr (msgpack , 'pack' , pack )
306- setattr (msgpack , 'packb' , packb )
307- setattr (msgpack , 'unpack' , unpack )
308- setattr (msgpack , 'unpackb' , unpackb )
306+ class Packer_ (Packer ):
307+ def __init__ (self , * args , ** kws ):
308+ super (Packer , self ).__init__ (* args , ** kws , allow_pickle = allow_pickle )
309+
310+ class Unpacker_ (Unpacker ):
311+ def __init__ (self , * args , ** kws ):
312+ super (Unpacker , self ).__init__ (* args , ** kws , allow_pickle = allow_pickle )
313+
314+ pack_ = functools .partial (pack , allow_pickle = allow_pickle )
315+ packb_ = functools .partial (packb , allow_pickle = allow_pickle )
316+ unpack_ = functools .partial (unpack , allow_pickle = allow_pickle )
317+ unpackb_ = functools .partial (unpackb , allow_pickle = allow_pickle )
318+
319+ setattr (msgpack , 'Packer' , Packer_ )
320+ setattr (msgpack , 'Unpacker' , Unpacker_ )
321+ setattr (msgpack , 'load' , unpack_ )
322+ setattr (msgpack , 'loads' , unpackb_ )
323+ setattr (msgpack , 'dump' , pack_ )
324+ setattr (msgpack , 'dumps' , packb_ )
325+ setattr (msgpack , 'pack' , pack_ )
326+ setattr (msgpack , 'packb' , packb_ )
327+ setattr (msgpack , 'unpack' , unpack_ )
328+ setattr (msgpack , 'unpackb' , unpackb_ )
0 commit comments