@@ -28,6 +28,9 @@ namespace tensorflow {
2828
2929typedef Eigen::ThreadPoolDevice CPUDevice;
3030typedef Eigen::GpuDevice GPUDevice;
31+ #ifdef TENSORFLOW_USE_SYCL
32+ typedef Eigen::SyclDevice SYCLDevice;
33+ #endif // TENSORFLOW_USE_SYCL
3134
3235template <typename Device, typename T>
3336class SoftmaxXentWithLogitsOp : public OpKernel {
@@ -74,17 +77,25 @@ class SoftmaxXentWithLogitsOp : public OpKernel {
7477// Partial specialization for a CPUDevice, that uses the Eigen implementation
7578// from XentEigenImpl.
7679namespace functor {
77- template <typename T>
78- struct XentFunctor <CPUDevice, T> {
79- void operator ()(const CPUDevice & d, typename TTypes<T>::ConstMatrix logits,
80+ template <typename Device, typename T>
81+ struct XentFunctorBase {
82+ void operator ()(const Device & d, typename TTypes<T>::ConstMatrix logits,
8083 typename TTypes<T>::ConstMatrix labels,
8184 typename TTypes<T>::Matrix scratch,
8285 typename TTypes<T>::Vec loss,
8386 typename TTypes<T>::Matrix backprop) {
84- XentEigenImpl<CPUDevice , T>::Compute (d, logits, labels, scratch, loss,
87+ XentEigenImpl<Device , T>::Compute (d, logits, labels, scratch, loss,
8588 backprop);
8689 }
8790};
91+
92+ template <typename T>
93+ struct XentFunctor <CPUDevice, T> : XentFunctorBase<CPUDevice, T> {};
94+
95+ #ifdef TENSORFLOW_USE_SYCL
96+ template <typename T>
97+ struct XentFunctor <SYCLDevice, T> : XentFunctorBase<SYCLDevice, T> {};
98+ #endif // TENSORFLOW_USE_SYCL
8899} // namespace functor
89100
90101#define REGISTER_CPU (T ) \
@@ -111,4 +122,11 @@ REGISTER_KERNEL_BUILDER(Name("SoftmaxCrossEntropyWithLogits")
111122 SoftmaxXentWithLogitsOp<GPUDevice, double>);
112123#endif // GOOGLE_CUDA
113124
125+ #ifdef TENSORFLOW_USE_SYCL
126+ REGISTER_KERNEL_BUILDER (Name(" SoftmaxCrossEntropyWithLogits" )
127+ .Device(DEVICE_SYCL)
128+ .TypeConstraint<float>(" T" ),
129+ SoftmaxXentWithLogitsOp<SYCLDevice, float>);
130+ #endif // TENSORFLOW_USE_SYCL
131+
114132} // namespace tensorflow
0 commit comments