Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/base/add.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ class Add : public Operator<Add> {
virtual void operator()(const Tensor input, const Tensor other,
Tensor out) const = 0;

template <typename TensorLike>
static auto MakeReturnValue(const TensorLike& input,
const TensorLike& other) {
return TensorLike::Empty(input.shape(), input.dtype(), input.device());
}

protected:
Tensor::Size ndim_{0};

Expand Down
7 changes: 7 additions & 0 deletions src/base/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ class Gemm : public Operator<Gemm> {
return operator()(a, b, alpha, beta, std::nullopt, std::nullopt, c);
}

template <typename TensorLike>
static auto MakeReturnValue(const TensorLike& a, const TensorLike& b) {
Tensor::Shape c_shape{a.shape()[a.shape().size() - 2],
b.shape()[b.shape().size() - 1]};
return TensorLike::Empty(c_shape, a.dtype(), a.device());
}

protected:
float alpha_{1.0};

Expand Down
62 changes: 62 additions & 0 deletions src/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
#include <memory>
#include <optional>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include "config.h"
Expand Down Expand Up @@ -72,6 +74,50 @@ bool ListContains(ValueType value, List<values...>) {
return ((value == static_cast<ValueType>(values)) || ...);
}

template <typename TensorLike, typename = void>
class IsTensorLike : public std::false_type {};

template <typename TensorLike>
class IsTensorLike<
TensorLike,
std::void_t<decltype(std::declval<const TensorLike&>().data()),
decltype(std::declval<const TensorLike&>().shape()),
decltype(std::declval<const TensorLike&>().strides()),
decltype(std::declval<const TensorLike&>().dtype()),
decltype(std::declval<const TensorLike&>().device())>>
: public std::true_type {};

template <typename T, typename std::enable_if_t<
IsTensorLike<std::decay_t<T>>::value, int> = 0>
Tensor AsCallArg(const T& tensor) {
return Tensor{tensor};
}

template <typename T, typename std::enable_if_t<
!IsTensorLike<std::decay_t<T>>::value, int> = 0>
const T& AsCallArg(const T& value) {
return value;
}

template <typename Key, typename TensorLike, typename Args, typename = void>
class HasMakeReturnValueImpl : public std::false_type {};

template <typename Key, typename TensorLike, typename... Args>
class HasMakeReturnValueImpl<
Key, TensorLike, std::tuple<Args...>,
std::void_t<decltype(Key::MakeReturnValue(std::declval<const TensorLike&>(),
std::declval<const Args&>()...))>>
: public std::true_type {};

template <typename Key, typename... Args>
class HasMakeReturnValueImpl<Key, Tensor, std::tuple<Args...>>
: public std::false_type {};

template <typename Key, typename TensorLike, typename... Args>
class HasMakeReturnValue
: public HasMakeReturnValueImpl<Key, std::decay_t<TensorLike>,
std::tuple<Args...>> {};

} // namespace infini::ops::detail

template <>
Expand Down Expand Up @@ -206,6 +252,14 @@ class Operator : public OperatorBase {
return Call({}, {}, tensor, args...);
}

template <
typename TensorLike, typename... Args,
typename std::enable_if_t<
detail::HasMakeReturnValue<Key, TensorLike, Args...>::value, int> = 0>
static auto Call(const TensorLike& tensor, const Args&... args) {
return CallReturning(tensor, args...);
}

static std::vector<std::size_t> active_implementation_indices(
Device::Type dev_type) {
if (!detail::ListContains(dev_type, ActiveDevices<Key>{})) {
Expand Down Expand Up @@ -245,6 +299,14 @@ class Operator : public OperatorBase {
static constexpr std::size_t implementation_index_{implementation_index};

private:
template <typename TensorLike, typename... Args>
static auto CallReturning(const TensorLike& tensor, const Args&... args) {
auto out = Key::MakeReturnValue(tensor, args...);
Key::Call(detail::AsCallArg(tensor), detail::AsCallArg(args)...,
detail::AsCallArg(out));
return out;
}

template <typename... Args>
static std::unique_ptr<Operator> MakeWithDevice(
const Config& config, Device::Type dispatch_device_type, Args&&... args) {
Expand Down
140 changes: 140 additions & 0 deletions tests/test_cpp_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,33 @@ def test_cpp_operator_call_instantiation_smoke(tmp_path):
str(source),
f"-L{library_dir}",
"-linfiniops",
"-linfinirt",
f"-Wl,-rpath,{library_dir}",
"-o",
str(binary),
]
)
_run([str(binary)])


def test_cpp_returning_call_smoke(tmp_path):
install_prefix = _install_prefix()
include_dir = install_prefix / "include"
library_dir = _library_dir(install_prefix)
source = tmp_path / "add_return_smoke.cc"
binary = tmp_path / "add_return_smoke"
source.write_text(_ADD_RETURN_SMOKE_SOURCE)

_run(
[
_compiler("CXX", "c++"),
"-std=c++17",
"-Werror",
f"-I{include_dir}",
str(source),
f"-L{library_dir}",
"-linfiniops",
"-linfinirt",
f"-Wl,-rpath,{library_dir}",
"-o",
str(binary),
Expand Down Expand Up @@ -113,3 +140,116 @@ def _run(command):
}
"""
).lstrip()


_ADD_RETURN_SMOKE_SOURCE = textwrap.dedent(
r"""
#include <infini/ops.h>

#include <cmath>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

class OwningTensor {
public:
using Shape = infini::ops::Tensor::Shape;
using Strides = infini::ops::Tensor::Strides;

OwningTensor(std::vector<float> data, Shape shape)
: data_{std::move(data)},
shape_{std::move(shape)},
strides_{ContiguousStrides(shape_)},
dtype_{infini::ops::DataType::kFloat32},
device_{infini::ops::Device::Type::kCpu} {}

static OwningTensor Empty(const Shape& shape, infini::ops::DataType dtype,
infini::ops::Device device) {
if (dtype != infini::ops::DataType::kFloat32 ||
device.type() != infini::ops::Device::Type::kCpu) {
throw std::runtime_error("unexpected output metadata");
}

return OwningTensor(std::vector<float>(Numel(shape)), shape);
}

void* data() { return data_.data(); }

const void* data() const { return data_.data(); }

const Shape& shape() const { return shape_; }

const Strides& strides() const { return strides_; }

infini::ops::DataType dtype() const { return dtype_; }

infini::ops::Device device() const { return device_; }

private:
static std::size_t Numel(const Shape& shape) {
return std::accumulate(shape.begin(), shape.end(), std::size_t{1},
std::multiplies<std::size_t>());
}

static Strides ContiguousStrides(const Shape& shape) {
if (shape.empty()) {
return {};
}

Strides strides(shape.size());
strides.back() = 1;
for (std::ptrdiff_t i = static_cast<std::ptrdiff_t>(shape.size()) - 2;
i >= 0; --i) {
strides[static_cast<std::size_t>(i)] =
strides[static_cast<std::size_t>(i + 1)] *
static_cast<infini::ops::Tensor::Stride>(
shape[static_cast<std::size_t>(i + 1)]);
}
return strides;
}

std::vector<float> data_;
Shape shape_;
Strides strides_;
infini::ops::DataType dtype_;
infini::ops::Device device_;
};

int main() {
OwningTensor input({1.0f, 2.0f, 3.0f}, {3});
OwningTensor other({4.0f, 5.0f, 6.0f}, {3});

auto output = infini::ops::Add::Call(input, other);
const auto* data = static_cast<const float*>(output.data());

if (output.shape() != OwningTensor::Shape{3}) {
return 1;
}
if (std::fabs(data[0] - 5.0f) > 1e-6f ||
std::fabs(data[1] - 7.0f) > 1e-6f ||
std::fabs(data[2] - 9.0f) > 1e-6f) {
return 1;
}

OwningTensor a({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3});
OwningTensor b({7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}, {3, 2});

auto c = infini::ops::Gemm::Call(a, b);
const auto* c_data = static_cast<const float*>(c.data());

if (c.shape() != OwningTensor::Shape{2, 2}) {
return 1;
}
if (std::fabs(c_data[0] - 58.0f) > 1e-6f ||
std::fabs(c_data[1] - 64.0f) > 1e-6f ||
std::fabs(c_data[2] - 139.0f) > 1e-6f ||
std::fabs(c_data[3] - 154.0f) > 1e-6f) {
return 1;
}

return 0;
}
"""
).lstrip()
Loading