Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ Scalar& ETaccessorBase::scalar() {
// then specialize to allow valid types (Eigen::Tensor's or true scalars)
// These are supported as run-time errors because the genericInterfaceC
// will access them by a name.
template<typename ERROR>
template<typename ERROR, bool copy=false>
class ETaccessor : public ETaccessorTyped<double> {
public:
using ET = Eigen::Tensor<double, 0>;
Expand Down Expand Up @@ -278,8 +278,8 @@ class ETaccessor : public ETaccessorTyped<double> {
};


template<typename Scalar, int nDim>
class ETaccessor<Eigen::Tensor<Scalar, nDim> > : public ETaccessorTyped<Scalar> {
template<typename Scalar, int nDim, bool copy>
class ETaccessor<Eigen::Tensor<Scalar, nDim>, copy> : public ETaccessorTyped<Scalar> {
public:
using ET = Eigen::Tensor<Scalar, nDim>;
// using Scalar = typename ET::Scalar;
Expand All @@ -306,7 +306,24 @@ class ETaccessor<Eigen::Tensor<Scalar, nDim> > : public ETaccessorTyped<Scalar>
std::vector<int> intDims_;
};

template<typename Scalar>
template<typename ET>
struct ETaccessorCopyHolder {
ET obj_copy;
ETaccessorCopyHolder(const ET &src) : obj_copy(src) {}
};

template<typename Scalar, int nDim>
class ETaccessor<Eigen::Tensor<Scalar, nDim>, true> :
private ETaccessorCopyHolder<Eigen::Tensor<Scalar, nDim>>,
public ETaccessor<Eigen::Tensor<Scalar, nDim>, false> {
public:
using ET = Eigen::Tensor<Scalar, nDim>;
using Holder = ETaccessorCopyHolder<ET>;
ETaccessor(const ET &obj_) : Holder(obj_), ETaccessor<ET, false>(Holder::obj_copy) {};
~ETaccessor() {};
};

template<typename Scalar, bool copy=false>
class ETaccessorScalar : public ETaccessorTyped<Scalar> {
public:
ETaccessorScalar(Scalar &obj_) : obj(obj_) {};
Expand All @@ -323,24 +340,38 @@ class ETaccessorScalar : public ETaccessorTyped<Scalar> {
std::vector<int> intDims_;
};

template<>
class ETaccessor<double> : public ETaccessorScalar<double> {
template<typename Scalar>
class ETaccessorScalar<Scalar, true> :
private ETaccessorCopyHolder<Scalar>,
public ETaccessorScalar<Scalar, false> {
public:
using ET = ETaccessorScalar<Scalar, false>;
using Holder = ETaccessorCopyHolder<Scalar>;
ETaccessorScalar(const Scalar &obj_) : Holder(obj_), ET(Holder::obj_copy) {};
~ETaccessorScalar() {};
};

template<bool copy>
class ETaccessor<double, copy> : public ETaccessorScalar<double, copy> {
using Ref = std::conditional_t<copy, const double&, double&>;
public:
ETaccessor(double &obj_) : ETaccessorScalar(obj_) {};
ETaccessor(Ref obj_) : ETaccessorScalar<double, copy>(obj_) {};
~ETaccessor() {};
};

template<>
class ETaccessor<int> : public ETaccessorScalar<int> {
template<bool copy>
class ETaccessor<int, copy> : public ETaccessorScalar<int, copy> {
using Ref = std::conditional_t<copy, const int&, int&>;
public:
ETaccessor(int &obj_) : ETaccessorScalar(obj_) {};
ETaccessor(Ref obj_) : ETaccessorScalar<int, copy>(obj_) {};
~ETaccessor() {};
};

template<>
class ETaccessor<bool> : public ETaccessorScalar<bool> {
template<bool copy>
class ETaccessor<bool, copy> : public ETaccessorScalar<bool, copy> {
using Ref = std::conditional_t<copy, const bool&, bool&>;
public:
ETaccessor(bool &obj_) : ETaccessorScalar(obj_) {};
ETaccessor(Ref obj_) : ETaccessorScalar<bool, copy>(obj_) {};
~ETaccessor() {};
};

Expand All @@ -351,9 +382,16 @@ Eigen::Tensor<Scalar, nDim> &ETaccessorBase::ref() {
return castptr->innerRef();
}

template<typename T>
auto ETaccess(T &x) -> ETaccessor<T>{
return ETaccessor<T>(x);
template<bool copy=false, typename T>
std::enable_if_t<!copy, ETaccessor<T, false>>
ETaccess(T &x) {
return ETaccessor<T, false>(x);
}

template<bool copy, typename T>
std::enable_if_t<copy, ETaccessor<T, true>>
ETaccess(const T &x) {
return ETaccessor<T, true>(x);
}

// end ETaccess
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ class CastingProxy {

template<typename TargetType>
class RuntimeCastingProxy {
// TargetType5o here should be a true scalar type,
// TargetType here should be a true scalar type,
// because specialization to Eigen::Tensor types is below.
typedef TargetType TargetScalar;

Expand Down Expand Up @@ -273,9 +273,9 @@ class RuntimeCastingProxy<Eigen::Tensor<TargetScalar, nDim> > {

// Compile-time source: delegates to ETaccessorTyped<Scalar>::asTyped<>().
// Returns EmptyProxy<TM>, EmptyProxy<STM>, RHSCastProxy, or CastingProxy.
template<typename TargetType, AsMode mode = AsMode::TM, typename T>
template<typename TargetType, AsMode mode = AsMode::TM, bool copy = false, typename T>
auto as_nC(T& x) {
return ETaccess(x).template asTyped<TargetType, mode>();
return ETaccess<copy>(x).template asTyped<TargetType, mode>();
}

// Runtime source: scalar type of acc is unknown at compile time.
Expand Down
65 changes: 65 additions & 0 deletions nCompiler/tests/testthat/cpp_tests/test-ETaccess.R
Original file line number Diff line number Diff line change
Expand Up @@ -653,3 +653,68 @@ test_that("access and ETaccess work for logicals", {
expect_error (obj$LM_get_set_arg_map_scalar(LM))
expect_identical(obj$LM_get_set_arg_map_scalar(matrix(LS)), !LS)
})

test_that("ETaccess<true> copies data and isolates from original", {
nc <- nClass(
Cpublic = list(
# Scalar: reading from copy gives correct value
DS_copy_read = nFunction(
function(DSx = 'numericScalar') {
cppLiteral('return ETaccess<true>(DSx).scalar();')
returnType('numericScalar')
}
),
# Scalar: writing to copy does not affect original
DS_copy_isolation = nFunction(
function(DSx = 'numericScalar') {
cppLiteral('{ auto acc = ETaccess<true>(DSx); acc.scalar() = 99.0; }')
cppLiteral('return ETaccess(DSx).scalar();')
returnType('numericScalar')
}
),
# Vector: reading from copy gives correct value
DV_copy_read = nFunction(
function(DVx = 'numericVector') {
cppLiteral('ans = ETaccess<true>(DVx).map<1>();', types = list(ans = 'numericVector'))
cppLiteral('return ans;')
returnType('numericVector')
}
),
# Vector: writing to copy does not affect original
DV_copy_isolation = nFunction(
function(DVx = 'numericVector') {
cppLiteral('{ auto acc = ETaccess<true>(DVx); acc.map<1>() = acc.map<1>() * 0.0; }')
cppLiteral('return ETaccess(DVx).map<1>();')
returnType('numericVector')
}
),
# Vector: copy is constructed from a lazy expression
DV_copy_expr = nFunction(
function(DVx = 'numericVector') {
cppLiteral('ans = ETaccess<true, Eigen::Tensor<double,1>>(DVx * 2.0).map<1>();',
types = list(ans = 'numericVector'))
cppLiteral('return ans;')
returnType('numericVector')
}
)
)
)
Cnc <- nCompile(nc)
obj <- Cnc$new()
DS <- 1.23456
DV <- DS * 1:3

# Scalar copy reads the correct initial value
expect_equal(obj$DS_copy_read(DS), DS)
# Scalar copy modification leaves the original unchanged
expect_equal(obj$DS_copy_isolation(DS), DS)

# Vector copy reads the correct initial values
expect_equal(obj$DV_copy_read(DV), DV)
# Vector copy modification leaves the original unchanged
expect_equal(obj$DV_copy_isolation(DV), DV)

# Lazy expression is evaluated into the copy; original is not involved
expect_equal(obj$DV_copy_expr(DV), DV * 2)
rm(obj); gc()
})
Loading