Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/base/add.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ namespace infini::ops {

class Add : public Operator<Add> {
public:
using Operator<Add>::Call;

Add(const Tensor input, const Tensor other, Tensor out)
: ndim_{out.ndim()},
output_size_{out.numel()},
Expand Down Expand Up @@ -36,6 +38,21 @@ class Add : public Operator<Add> {
virtual void operator()(const Tensor input, const Tensor other,
Tensor out) const = 0;

template <typename TensorLike>
static auto Call(const TensorLike& input, const TensorLike& other) {
auto out = TensorLike::Empty(input.shape(), input.dtype(), input.device());
Tensor input_view{const_cast<void*>(static_cast<const void*>(input.data())),
input.shape(), input.dtype(), input.device(),
input.strides()};
Tensor other_view{const_cast<void*>(static_cast<const void*>(other.data())),
other.shape(), other.dtype(), other.device(),
other.strides()};
Tensor out_view{out.data(), out.shape(), out.dtype(), out.device(),
out.strides()};
Add::Call(input_view, other_view, out_view);
return out;
}

protected:
Tensor::Size ndim_{0};

Expand Down
29 changes: 29 additions & 0 deletions src/base/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ namespace infini::ops {

class Gemm : public Operator<Gemm> {
public:
using Operator<Gemm>::Call;

Gemm(const Tensor a, const Tensor b, std::optional<float> alpha,
std::optional<float> beta, std::optional<int> trans_a,
std::optional<int> trans_b, Tensor c)
Expand Down Expand Up @@ -55,6 +57,33 @@ class Gemm : public Operator<Gemm> {
return operator()(a, b, alpha, beta, std::nullopt, std::nullopt, c);
}

template <typename TensorLike>
static auto Call(const TensorLike& a, const TensorLike& b) {
return Call(a, b, false, false);
}

template <typename TensorLike>
static auto Call(const TensorLike& a, const TensorLike& b, bool trans_a,
bool trans_b) {
Tensor::Shape c_shape{
trans_a ? a.shape()[a.shape().size() - 1]
: a.shape()[a.shape().size() - 2],
trans_b ? b.shape()[b.shape().size() - 2]
: b.shape()[b.shape().size() - 1],
};
auto c = TensorLike::Empty(c_shape, a.dtype(), a.device());
Tensor a_view{const_cast<void*>(static_cast<const void*>(a.data())),
a.shape(), a.dtype(), a.device(), a.strides()};
Tensor b_view{const_cast<void*>(static_cast<const void*>(b.data())),
b.shape(), b.dtype(), b.device(), b.strides()};
Tensor c_view{c.data(), c.shape(), c.dtype(), c.device(), c.strides()};
Gemm::Call(a_view, b_view, std::optional<float>{1.0f},
std::optional<float>{0.0f},
std::optional<int>{static_cast<int>(trans_a)},
std::optional<int>{static_cast<int>(trans_b)}, c_view);
return c;
}

protected:
float alpha_{1.0};

Expand Down
140 changes: 140 additions & 0 deletions tests/test_cpp_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,33 @@ def test_cpp_operator_call_instantiation_smoke(tmp_path):
str(source),
f"-L{library_dir}",
"-linfiniops",
"-linfinirt",
f"-Wl,-rpath,{library_dir}",
"-o",
str(binary),
]
)
_run([str(binary)])


def test_cpp_returning_call_smoke(tmp_path):
install_prefix = _install_prefix()
include_dir = install_prefix / "include"
library_dir = _library_dir(install_prefix)
source = tmp_path / "add_return_smoke.cc"
binary = tmp_path / "add_return_smoke"
source.write_text(_ADD_RETURN_SMOKE_SOURCE)

_run(
[
_compiler("CXX", "c++"),
"-std=c++17",
"-Werror",
f"-I{include_dir}",
str(source),
f"-L{library_dir}",
"-linfiniops",
"-linfinirt",
f"-Wl,-rpath,{library_dir}",
"-o",
str(binary),
Expand Down Expand Up @@ -113,3 +140,116 @@ def _run(command):
}
"""
).lstrip()


_ADD_RETURN_SMOKE_SOURCE = textwrap.dedent(
r"""
#include <infini/ops.h>

#include <cmath>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <utility>
#include <vector>

class OwningTensor {
public:
using Shape = infini::ops::Tensor::Shape;
using Strides = infini::ops::Tensor::Strides;

OwningTensor(std::vector<float> data, Shape shape)
: data_{std::move(data)},
shape_{std::move(shape)},
strides_{ContiguousStrides(shape_)},
dtype_{infini::ops::DataType::kFloat32},
device_{infini::ops::Device::Type::kCpu} {}

static OwningTensor Empty(const Shape& shape, infini::ops::DataType dtype,
infini::ops::Device device) {
if (dtype != infini::ops::DataType::kFloat32 ||
device.type() != infini::ops::Device::Type::kCpu) {
throw std::runtime_error("unexpected output metadata");
}

return OwningTensor(std::vector<float>(Numel(shape)), shape);
}

void* data() { return data_.data(); }

const void* data() const { return data_.data(); }

const Shape& shape() const { return shape_; }

const Strides& strides() const { return strides_; }

infini::ops::DataType dtype() const { return dtype_; }

infini::ops::Device device() const { return device_; }

private:
static std::size_t Numel(const Shape& shape) {
return std::accumulate(shape.begin(), shape.end(), std::size_t{1},
std::multiplies<std::size_t>());
}

static Strides ContiguousStrides(const Shape& shape) {
if (shape.empty()) {
return {};
}

Strides strides(shape.size());
strides.back() = 1;
for (std::ptrdiff_t i = static_cast<std::ptrdiff_t>(shape.size()) - 2;
i >= 0; --i) {
strides[static_cast<std::size_t>(i)] =
strides[static_cast<std::size_t>(i + 1)] *
static_cast<infini::ops::Tensor::Stride>(
shape[static_cast<std::size_t>(i + 1)]);
}
return strides;
}

std::vector<float> data_;
Shape shape_;
Strides strides_;
infini::ops::DataType dtype_;
infini::ops::Device device_;
};

int main() {
OwningTensor input({1.0f, 2.0f, 3.0f}, {3});
OwningTensor other({4.0f, 5.0f, 6.0f}, {3});

auto output = infini::ops::Add::Call(input, other);
const auto* data = static_cast<const float*>(output.data());

if (output.shape() != OwningTensor::Shape{3}) {
return 1;
}
if (std::fabs(data[0] - 5.0f) > 1e-6f ||
std::fabs(data[1] - 7.0f) > 1e-6f ||
std::fabs(data[2] - 9.0f) > 1e-6f) {
return 1;
}

OwningTensor a({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {2, 3});
OwningTensor b({7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}, {3, 2});

auto c = infini::ops::Gemm::Call(a, b);
const auto* c_data = static_cast<const float*>(c.data());

if (c.shape() != OwningTensor::Shape{2, 2}) {
return 1;
}
if (std::fabs(c_data[0] - 58.0f) > 1e-6f ||
std::fabs(c_data[1] - 64.0f) > 1e-6f ||
std::fabs(c_data[2] - 139.0f) > 1e-6f ||
std::fabs(c_data[3] - 154.0f) > 1e-6f) {
return 1;
}

return 0;
}
"""
).lstrip()
Loading