diff --git a/nCompiler/inst/include/nCompiler/ET_Rcpp_ext/post_Rcpp/ETaccessor_post_Rcpp.h b/nCompiler/inst/include/nCompiler/ET_Rcpp_ext/post_Rcpp/ETaccessor_post_Rcpp.h index 07dd3786..075b96a6 100644 --- a/nCompiler/inst/include/nCompiler/ET_Rcpp_ext/post_Rcpp/ETaccessor_post_Rcpp.h +++ b/nCompiler/inst/include/nCompiler/ET_Rcpp_ext/post_Rcpp/ETaccessor_post_Rcpp.h @@ -241,7 +241,7 @@ Scalar& ETaccessorBase::scalar() { // then specialize to allow valid types (Eigen::Tensor's or true scalars) // These are supported as run-time errors because the genericInterfaceC // will access them by a name. -template +template class ETaccessor : public ETaccessorTyped { public: using ET = Eigen::Tensor; @@ -278,8 +278,8 @@ class ETaccessor : public ETaccessorTyped { }; -template -class ETaccessor > : public ETaccessorTyped { +template +class ETaccessor, copy> : public ETaccessorTyped { public: using ET = Eigen::Tensor; // using Scalar = typename ET::Scalar; @@ -306,7 +306,24 @@ class ETaccessor > : public ETaccessorTyped std::vector intDims_; }; -template +template +struct ETaccessorCopyHolder { + ET obj_copy; + ETaccessorCopyHolder(const ET &src) : obj_copy(src) {} +}; + +template +class ETaccessor, true> : + private ETaccessorCopyHolder>, + public ETaccessor, false> { +public: + using ET = Eigen::Tensor; + using Holder = ETaccessorCopyHolder; + ETaccessor(const ET &obj_) : Holder(obj_), ETaccessor(Holder::obj_copy) {}; + ~ETaccessor() {}; +}; + +template class ETaccessorScalar : public ETaccessorTyped { public: ETaccessorScalar(Scalar &obj_) : obj(obj_) {}; @@ -323,24 +340,38 @@ class ETaccessorScalar : public ETaccessorTyped { std::vector intDims_; }; -template<> -class ETaccessor : public ETaccessorScalar { +template +class ETaccessorScalar : + private ETaccessorCopyHolder, + public ETaccessorScalar { +public: + using ET = ETaccessorScalar; + using Holder = ETaccessorCopyHolder; + ETaccessorScalar(const Scalar &obj_) : Holder(obj_), ET(Holder::obj_copy) {}; + ~ETaccessorScalar() {}; +}; + +template +class ETaccessor : public ETaccessorScalar { + using Ref = std::conditional_t; public: - ETaccessor(double &obj_) : ETaccessorScalar(obj_) {}; + ETaccessor(Ref obj_) : ETaccessorScalar(obj_) {}; ~ETaccessor() {}; }; -template<> -class ETaccessor : public ETaccessorScalar { +template +class ETaccessor : public ETaccessorScalar { + using Ref = std::conditional_t; public: - ETaccessor(int &obj_) : ETaccessorScalar(obj_) {}; + ETaccessor(Ref obj_) : ETaccessorScalar(obj_) {}; ~ETaccessor() {}; }; -template<> -class ETaccessor : public ETaccessorScalar { +template +class ETaccessor : public ETaccessorScalar { + using Ref = std::conditional_t; public: - ETaccessor(bool &obj_) : ETaccessorScalar(obj_) {}; + ETaccessor(Ref obj_) : ETaccessorScalar(obj_) {}; ~ETaccessor() {}; }; @@ -351,9 +382,16 @@ Eigen::Tensor &ETaccessorBase::ref() { return castptr->innerRef(); } -template -auto ETaccess(T &x) -> ETaccessor{ - return ETaccessor(x); +template +std::enable_if_t> +ETaccess(T &x) { + return ETaccessor(x); +} + +template +std::enable_if_t> +ETaccess(const T &x) { + return ETaccessor(x); } // end ETaccess diff --git a/nCompiler/inst/include/nCompiler/ET_Rcpp_ext/post_Rcpp/nC_as.h b/nCompiler/inst/include/nCompiler/ET_Rcpp_ext/post_Rcpp/nC_as.h index c8001718..0b82774b 100644 --- a/nCompiler/inst/include/nCompiler/ET_Rcpp_ext/post_Rcpp/nC_as.h +++ b/nCompiler/inst/include/nCompiler/ET_Rcpp_ext/post_Rcpp/nC_as.h @@ -131,7 +131,7 @@ class CastingProxy { template class RuntimeCastingProxy { - // TargetType5o here should be a true scalar type, + // TargetType here should be a true scalar type, // because specialization to Eigen::Tensor types is below. typedef TargetType TargetScalar; @@ -273,9 +273,9 @@ class RuntimeCastingProxy > { // Compile-time source: delegates to ETaccessorTyped::asTyped<>(). // Returns EmptyProxy, EmptyProxy, RHSCastProxy, or CastingProxy. -template +template auto as_nC(T& x) { - return ETaccess(x).template asTyped(); + return ETaccess(x).template asTyped(); } // Runtime source: scalar type of acc is unknown at compile time. diff --git a/nCompiler/tests/testthat/cpp_tests/test-ETaccess.R b/nCompiler/tests/testthat/cpp_tests/test-ETaccess.R index da66e81e..90ee616a 100644 --- a/nCompiler/tests/testthat/cpp_tests/test-ETaccess.R +++ b/nCompiler/tests/testthat/cpp_tests/test-ETaccess.R @@ -653,3 +653,68 @@ test_that("access and ETaccess work for logicals", { expect_error (obj$LM_get_set_arg_map_scalar(LM)) expect_identical(obj$LM_get_set_arg_map_scalar(matrix(LS)), !LS) }) + +test_that("ETaccess copies data and isolates from original", { + nc <- nClass( + Cpublic = list( + # Scalar: reading from copy gives correct value + DS_copy_read = nFunction( + function(DSx = 'numericScalar') { + cppLiteral('return ETaccess(DSx).scalar();') + returnType('numericScalar') + } + ), + # Scalar: writing to copy does not affect original + DS_copy_isolation = nFunction( + function(DSx = 'numericScalar') { + cppLiteral('{ auto acc = ETaccess(DSx); acc.scalar() = 99.0; }') + cppLiteral('return ETaccess(DSx).scalar();') + returnType('numericScalar') + } + ), + # Vector: reading from copy gives correct value + DV_copy_read = nFunction( + function(DVx = 'numericVector') { + cppLiteral('ans = ETaccess(DVx).map<1>();', types = list(ans = 'numericVector')) + cppLiteral('return ans;') + returnType('numericVector') + } + ), + # Vector: writing to copy does not affect original + DV_copy_isolation = nFunction( + function(DVx = 'numericVector') { + cppLiteral('{ auto acc = ETaccess(DVx); acc.map<1>() = acc.map<1>() * 0.0; }') + cppLiteral('return ETaccess(DVx).map<1>();') + returnType('numericVector') + } + ), + # Vector: copy is constructed from a lazy expression + DV_copy_expr = nFunction( + function(DVx = 'numericVector') { + cppLiteral('ans = ETaccess>(DVx * 2.0).map<1>();', + types = list(ans = 'numericVector')) + cppLiteral('return ans;') + returnType('numericVector') + } + ) + ) + ) + Cnc <- nCompile(nc) + obj <- Cnc$new() + DS <- 1.23456 + DV <- DS * 1:3 + + # Scalar copy reads the correct initial value + expect_equal(obj$DS_copy_read(DS), DS) + # Scalar copy modification leaves the original unchanged + expect_equal(obj$DS_copy_isolation(DS), DS) + + # Vector copy reads the correct initial values + expect_equal(obj$DV_copy_read(DV), DV) + # Vector copy modification leaves the original unchanged + expect_equal(obj$DV_copy_isolation(DV), DV) + + # Lazy expression is evaluated into the copy; original is not involved + expect_equal(obj$DV_copy_expr(DV), DV * 2) + rm(obj); gc() +})