@@ -119,19 +119,6 @@ class Dat(petsc_Dat):
119119 """
120120 Dat for GPU.
121121 """
122- @validate_type (('dataset' , (base .DataCarrier , DataSet , Set ), DataSetTypeError ),
123- ('name' , str , NameTypeError ))
124- @validate_dtype (('dtype' , None , DataTypeError ))
125- def __init__ (self , dataset , data = None , dtype = None , name = None , uid = None ):
126-
127- if isinstance (dataset , petsc_Dat ) and not isinstance (dataset , Dat ):
128- self .__init__ (dataset .dataset , None , dtype = dataset .dtype ,
129- name = "copy_of_%s" % dataset .name )
130- self ._data [...] = dataset .data
131- return
132-
133- super (Dat , self ).__init__ (dataset , data , dtype , name , uid )
134-
135122 @cached_property
136123 def _vec (self ):
137124 assert self .dtype == PETSc .ScalarType , \
@@ -149,47 +136,6 @@ def _vec(self):
149136
150137 return cuda_vec
151138
152- @cached_property
153- def device_handle (self ):
154- if self .dtype == PETSc .ScalarType :
155- with self .vec as v :
156- return v .getCUDAHandle ()
157- elif self .dtype == PETSc .IntType :
158- m_gpu = cuda .mem_alloc (int (self ._data .nbytes ))
159- cuda .memcpy_htod (m_gpu , self ._data )
160- return m_gpu
161- else :
162- raise NotImplementedError ("Unknown type: %s." % self .dtype )
163-
164- @cached_property
165- def _kernel_args_ (self ):
166- return (self .device_handle , )
167-
168- @collective
169- @property
170- def data (self ):
171-
172- with self .vec as v :
173- v .restoreCUDAHandle (self .device_handle )
174- return v .array
175-
176- ## TODO: fail when trying to acess elems from data_ro
177- @collective
178- @property
179- def data_ro (self ):
180- with self .vec_ro as v :
181- v .restoreCUDAHandle (self .device_handle )
182- return v .array
183-
184-
185- def move_to_host (self ):
186- with self .vec_ro as v :
187- v .restoreCUDAHandle (self .device_handle )
188- self ._data = v .array
189-
190- return petsc_Dat (self ._dataset , self ._data , self .dtype , 'copy_of_%s' %
191- self .name )
192-
193139
194140class Global (petsc_Global ):
195141
@@ -226,6 +172,7 @@ def __init__(self, kernel, iterset, *args, **kwargs):
226172 otherwise they (and the :class:`~.Dat`\s, :class:`~.Map`\s
227173 and :class:`~.Mat`\s they reference) will never be collected.
228174 """
175+
229176 # Return early if we were in the cache.
230177 if self ._initialized :
231178 return
@@ -326,6 +273,12 @@ def ith_added_global_arg_i(self, i):
326273
327274 @collective
328275 def __call__ (self , * args ):
276+ # FIXME: Should probably get rid of this once the implementation is
277+ # finalized.
278+ from pyop2 .op2 import device
279+ import pyop2 .gpu .cuda
280+ assert device .target == pyop2 .gpu .cuda
281+
329282 if self ._initialized :
330283 grid , block = self .grid_size (args [0 ], args [1 ])
331284 extra_global_args = self .get_args_marked_for_globals
@@ -417,13 +370,11 @@ def argtypes(self):
417370 argtypes = (index_type , index_type )
418371 argtypes += self ._iterset ._argtypes_
419372 for arg in self ._args :
420- assert isinstance (arg .data , Dat )
421373 argtypes += arg ._argtypes_
422374 seen = set ()
423375 for arg in self ._args :
424376 maps = arg .map_tuple
425377 for map_ in maps :
426- assert isinstance (map_ , Map )
427378 for k , t in zip (map_ ._kernel_args_ , map_ ._argtypes_ ):
428379 if k in seen :
429380 continue
0 commit comments