diff --git a/src/base/add.h b/src/base/add.h index 06bfa4ccd..eb4b31f4a 100644 --- a/src/base/add.h +++ b/src/base/add.h @@ -36,6 +36,12 @@ class Add : public Operator { virtual void operator()(const Tensor input, const Tensor other, Tensor out) const = 0; + template + static auto MakeReturnValue(const TensorLike& input, + const TensorLike& other) { + return TensorLike::Empty(input.shape(), input.dtype(), input.device()); + } + protected: Tensor::Size ndim_{0}; diff --git a/src/base/gemm.h b/src/base/gemm.h index 0bb350283..c0b0fdc35 100644 --- a/src/base/gemm.h +++ b/src/base/gemm.h @@ -55,6 +55,13 @@ class Gemm : public Operator { return operator()(a, b, alpha, beta, std::nullopt, std::nullopt, c); } + template + static auto MakeReturnValue(const TensorLike& a, const TensorLike& b) { + Tensor::Shape c_shape{a.shape()[a.shape().size() - 2], + b.shape()[b.shape().size() - 1]}; + return TensorLike::Empty(c_shape, a.dtype(), a.device()); + } + protected: float alpha_{1.0}; diff --git a/src/operator.h b/src/operator.h index 2dfb2ff7e..dc34d25bc 100644 --- a/src/operator.h +++ b/src/operator.h @@ -5,7 +5,9 @@ #include #include #include +#include #include +#include #include #include "config.h" @@ -72,6 +74,50 @@ bool ListContains(ValueType value, List) { return ((value == static_cast(values)) || ...); } +template +class IsTensorLike : public std::false_type {}; + +template +class IsTensorLike< + TensorLike, + std::void_t().data()), + decltype(std::declval().shape()), + decltype(std::declval().strides()), + decltype(std::declval().dtype()), + decltype(std::declval().device())>> + : public std::true_type {}; + +template >::value, int> = 0> +Tensor AsCallArg(const T& tensor) { + return Tensor{tensor}; +} + +template >::value, int> = 0> +const T& AsCallArg(const T& value) { + return value; +} + +template +class HasMakeReturnValueImpl : public std::false_type {}; + +template +class HasMakeReturnValueImpl< + Key, TensorLike, std::tuple, + std::void_t(), + std::declval()...))>> + : public std::true_type {}; + +template +class HasMakeReturnValueImpl> + : public std::false_type {}; + +template +class HasMakeReturnValue + : public HasMakeReturnValueImpl, + std::tuple> {}; + } // namespace infini::ops::detail template <> @@ -206,6 +252,14 @@ class Operator : public OperatorBase { return Call({}, {}, tensor, args...); } + template < + typename TensorLike, typename... Args, + typename std::enable_if_t< + detail::HasMakeReturnValue::value, int> = 0> + static auto Call(const TensorLike& tensor, const Args&... args) { + return CallReturning(tensor, args...); + } + static std::vector active_implementation_indices( Device::Type dev_type) { if (!detail::ListContains(dev_type, ActiveDevices{})) { @@ -245,6 +299,14 @@ class Operator : public OperatorBase { static constexpr std::size_t implementation_index_{implementation_index}; private: + template + static auto CallReturning(const TensorLike& tensor, const Args&... args) { + auto out = Key::MakeReturnValue(tensor, args...); + Key::Call(detail::AsCallArg(tensor), detail::AsCallArg(args)..., + detail::AsCallArg(out)); + return out; + } + template static std::unique_ptr MakeWithDevice( const Config& config, Device::Type dispatch_device_type, Args&&... args) { diff --git a/tests/test_cpp_api.py b/tests/test_cpp_api.py index c2274f61f..02933b8bf 100644 --- a/tests/test_cpp_api.py +++ b/tests/test_cpp_api.py @@ -23,6 +23,33 @@ def test_cpp_operator_call_instantiation_smoke(tmp_path): str(source), f"-L{library_dir}", "-linfiniops", + "-linfinirt", + f"-Wl,-rpath,{library_dir}", + "-o", + str(binary), + ] + ) + _run([str(binary)]) + + +def test_cpp_returning_call_smoke(tmp_path): + install_prefix = _install_prefix() + include_dir = install_prefix / "include" + library_dir = _library_dir(install_prefix) + source = tmp_path / "add_return_smoke.cc" + binary = tmp_path / "add_return_smoke" + source.write_text(_ADD_RETURN_SMOKE_SOURCE) + + _run( + [ + _compiler("CXX", "c++"), + "-std=c++17", + "-Werror", + f"-I{include_dir}", + str(source), + f"-L{library_dir}", + "-linfiniops", + "-linfinirt", f"-Wl,-rpath,{library_dir}", "-o", str(binary), @@ -113,3 +140,116 @@ def _run(command): } """ ).lstrip() + + +_ADD_RETURN_SMOKE_SOURCE = textwrap.dedent( + r""" + #include + + #include + #include + #include + #include + #include + #include + + class OwningTensor { + public: + using Shape = infini::ops::Tensor::Shape; + using Strides = infini::ops::Tensor::Strides; + + OwningTensor(std::vector data, Shape shape) + : data_{std::move(data)}, + shape_{std::move(shape)}, + strides_{ContiguousStrides(shape_)}, + dtype_{infini::ops::DataType::kFloat32}, + device_{infini::ops::Device::Type::kCpu} {} + + static OwningTensor Empty(const Shape& shape, infini::ops::DataType dtype, + infini::ops::Device device) { + if (dtype != infini::ops::DataType::kFloat32 || + device.type() != infini::ops::Device::Type::kCpu) { + throw std::runtime_error("unexpected output metadata"); + } + + return OwningTensor(std::vector(Numel(shape)), shape); + } + + void* data() { return data_.data(); } + + const void* data() const { return data_.data(); } + + const Shape& shape() const { return shape_; } + + const Strides& strides() const { return strides_; } + + infini::ops::DataType dtype() const { return dtype_; } + + infini::ops::Device device() const { return device_; } + + private: + static std::size_t Numel(const Shape& shape) { + return std::accumulate(shape.begin(), shape.end(), std::size_t{1}, + std::multiplies()); + } + + static Strides ContiguousStrides(const Shape& shape) { + if (shape.empty()) { + return {}; + } + + Strides strides(shape.size()); + strides.back() = 1; + for (std::ptrdiff_t i = static_cast(shape.size()) - 2; + i >= 0; --i) { + strides[static_cast(i)] = + strides[static_cast(i + 1)] * + static_cast( + shape[static_cast(i + 1)]); + } + return strides; + } + + std::vector data_; + Shape shape_; + Strides strides_; + infini::ops::DataType dtype_; + infini::ops::Device device_; + }; + + int main() { + OwningTensor input({1.0f, 2.0f, 3.0f}, {3}); + OwningTensor other({4.0f, 5.0f, 6.0f}, {3}); + + auto output = infini::ops::Add::Call(input, other); + const auto* data = static_cast(output.data()); + + if (output.shape() != OwningTensor::Shape{3}) { + return 1; + } + if (std::fabs(data[0] - 5.0f) > 1e-6f || + std::fabs(data[1] - 7.0f) > 1e-6f || + std::fabs(data[2] - 9.0f) > 1e-6f) { + return 1; + } + + OwningTensor a({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3}); + OwningTensor b({7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}, {3, 2}); + + auto c = infini::ops::Gemm::Call(a, b); + const auto* c_data = static_cast(c.data()); + + if (c.shape() != OwningTensor::Shape{2, 2}) { + return 1; + } + if (std::fabs(c_data[0] - 58.0f) > 1e-6f || + std::fabs(c_data[1] - 64.0f) > 1e-6f || + std::fabs(c_data[2] - 139.0f) > 1e-6f || + std::fabs(c_data[3] - 154.0f) > 1e-6f) { + return 1; + } + + return 0; + } + """ +).lstrip()