1919
2020from numcodecs .abc import Codec
2121
22+ from ._chunked import ChunkedNdArray
2223from .abc import CodecCombinatorMixin
2324
2425
@@ -89,11 +90,22 @@ def encode(self, buf: Buffer) -> Buffer:
8990 protocol.
9091 """
9192
93+ if len (self ) == 0 :
94+ return buf
95+
96+ chunked = getattr (buf , "chunked" , False )
97+
9298 encoded = buf
9399 for codec in self :
94- encoded = codec . encode (
100+ encoded_ndarray = np . asarray (
95101 numcodecs .compat .ensure_contiguous_ndarray_like (encoded , flatten = False )
96102 )
103+ encoded = codec .encode (
104+ ChunkedNdArray (encoded_ndarray ) if chunked else encoded_ndarray
105+ )
106+
107+ if getattr (encoded , "chunked" , False ):
108+ return np .array (encoded ).view (np .ndarray ) # type: ignore
97109 return encoded
98110
99111 def decode (self , buf : Buffer , out : Optional [Buffer ] = None ) -> Buffer :
@@ -151,7 +163,7 @@ def encode_decode(self, buf: Buffer) -> Buffer:
151163 silhouettes .append ((encoded .shape , encoded .dtype ))
152164 encoded = np .asarray (
153165 numcodecs .compat .ensure_contiguous_ndarray_like (
154- codec .encode (_MaybeChunkedNdArray (encoded ) if chunked else encoded ),
166+ codec .encode (ChunkedNdArray (encoded ) if chunked else encoded ),
155167 flatten = False ,
156168 )
157169 )
@@ -162,12 +174,13 @@ def encode_decode(self, buf: Buffer) -> Buffer:
162174 shape , dtype = silhouettes .pop ()
163175 out = np .empty (shape = shape , dtype = dtype )
164176 decoded = (
165- codec .decode (decoded , _MaybeChunkedNdArray (out ) if chunked else out )
177+ codec .decode (decoded , ChunkedNdArray (out ) if chunked else out )
166178 .view (dtype )
167179 .reshape (shape )
168180 )
169181
170- decoded = decoded .view (np .ndarray )
182+ if getattr (decoded , "chunked" , False ):
183+ decoded = decoded .view (np .ndarray )
171184
172185 if isinstance (decoded , type (buf )):
173186 return decoded
@@ -205,7 +218,8 @@ def encode_decode_data_array(self, da: "xr.DataArray") -> "xr.DataArray":
205218
206219 import xarray as xr
207220
208- chunked = da .chunks is not None
221+ if da .chunks is None :
222+ return da .copy (data = self .encode_decode (da .values )) # type: ignore
209223
210224 def encode_decode_data_array_single_chunk (
211225 da : xr .DataArray ,
@@ -217,7 +231,7 @@ def encode_decode_data_array_single_chunk(
217231 return da .copy (deep = False ).chunk (single_chunk )
218232
219233 # eagerly compute the input chunk and encode and decode it
220- decoded = self .encode_decode (_MaybeChunkedNdArray (da .values , chunked )) # type: ignore
234+ decoded = self .encode_decode (ChunkedNdArray (da .values )) # type: ignore
221235
222236 return da .copy (deep = False , data = np .array (decoded ).view (np .ndarray )).chunk (
223237 single_chunk
@@ -307,22 +321,3 @@ def __rmul__(self, other) -> "CodecStack":
307321
308322
309323numcodecs .registry .register_codec (CodecStack )
310-
311-
312- class _MaybeChunkedNdArray (np .ndarray ):
313- __slots__ = ("_chunked" ,)
314- _chunked : bool
315-
316- def __new__ (cls , array , chunked : bool = True ):
317- obj = np .asarray (array ).view (cls )
318- obj ._chunked = chunked
319- return obj
320-
321- def __array_finalize__ (self , obj ):
322- if obj is None :
323- return
324- self ._chunked = getattr (obj , "chunked" , True )
325-
326- @property
327- def chunked (self ) -> bool :
328- return self ._chunked
0 commit comments