Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions runtime/core/device_allocator.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/runtime/core/device_allocator.h>

#include <executorch/runtime/platform/assert.h>

namespace executorch {
namespace runtime {

DeviceAllocatorRegistry& DeviceAllocatorRegistry::instance() {
static DeviceAllocatorRegistry registry;
return registry;
}

void DeviceAllocatorRegistry::register_allocator(
etensor::DeviceType type,
DeviceAllocator* alloc) {
auto index = static_cast<size_t>(type);
ET_CHECK_MSG(
index < etensor::kNumDeviceTypes,
"Invalid device type: %d",
static_cast<int>(type));
ET_CHECK_MSG(
allocators_[index] == nullptr,
"Allocator already registered for device type: %d",
static_cast<int>(type));
allocators_[index] = alloc;
}

DeviceAllocator* DeviceAllocatorRegistry::get_allocator(
etensor::DeviceType type) {
auto index = static_cast<size_t>(type);
if (index >= etensor::kNumDeviceTypes) {
return nullptr;
}
return allocators_[index];
}

// Convenience free functions

void register_device_allocator(
etensor::DeviceType type,
DeviceAllocator* alloc) {
DeviceAllocatorRegistry::instance().register_allocator(type, alloc);
}

DeviceAllocator* get_device_allocator(etensor::DeviceType type) {
return DeviceAllocatorRegistry::instance().get_allocator(type);
}

} // namespace runtime
} // namespace executorch
156 changes: 156 additions & 0 deletions runtime/core/device_allocator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <cstddef>

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/portable_type/device.h>
#include <executorch/runtime/core/result.h>

namespace executorch {
namespace runtime {

/**
* Abstract interface for device-specific memory allocation.
*
* Each device type (CUDA, etc.) provides a concrete implementation
* that handles memory allocation on that device. Implementations are
* expected to be singletons with static lifetime, registered via
* DeviceAllocatorRegistry.

*/
class DeviceAllocator {
public:
virtual ~DeviceAllocator() = default;
/**
* Allocate device memory.
*
* @param nbytes Number of bytes to allocate.
* @param index The device index.
* @return A Result containing the device pointer on success, or an error.
*/
virtual Result<void*> allocate(size_t nbytes, etensor::DeviceIndex index) = 0;

/**
* Deallocate device memory previously allocated via allocate().
*
* @param ptr Pointer to the memory to deallocate.
* @param index The device index.
*/
virtual void deallocate(void* ptr, etensor::DeviceIndex index) = 0;

/**
* Copy data from host memory to device memory.
*
* @param dst Destination pointer (device memory).
* @param src Source pointer (host memory).
* @param nbytes Number of bytes to copy.
* @param index The device index.
* @return Error::Ok on success, or an appropriate error code on failure.
*/
virtual Error copy_host_to_device(
void* dst,
const void* src,
size_t nbytes,
etensor::DeviceIndex index) = 0;

/**
* Copy data from device memory to host memory.
*
* @param dst Destination pointer (host memory).
* @param src Source pointer (device memory).
* @param nbytes Number of bytes to copy.
* @param index The device index.
* @return Error::Ok on success, or an appropriate error code on failure.
*/
virtual Error copy_device_to_host(
void* dst,
const void* src,
size_t nbytes,
etensor::DeviceIndex index) = 0;

/**
* Returns the device type this allocator handles.
*/
virtual etensor::DeviceType device_type() const = 0;
};

/**
* Registry for device allocators.
*
* Provides a global mapping from DeviceType to DeviceAllocator instances.
* Device allocators register themselves at static initialization time,
* and the runtime queries the registry to find the appropriate allocator
* for a given device type.
*/
class DeviceAllocatorRegistry {
public:
/**
* Returns the singleton instance of the registry.
*/
static DeviceAllocatorRegistry& instance();

/**
* Register an allocator for a specific device type.
*
* @param type The device type this allocator handles.
* @param alloc Pointer to the allocator (must have static lifetime).
*/
void register_allocator(etensor::DeviceType type, DeviceAllocator* alloc);

/**
* Get the allocator for a specific device type.
*
* @param type The device type.
* @return Pointer to the allocator, or nullptr if not registered.
*/
DeviceAllocator* get_allocator(etensor::DeviceType type);

private:
DeviceAllocatorRegistry() = default;

// Fixed-size array indexed by device type. This avoids dynamic allocation
// and is suitable for embedded environments.
DeviceAllocator* allocators_[etensor::kNumDeviceTypes] = {};
};

// Convenience free functions

/**
* Register a device allocator for a specific device type.
*
* @param type The device type this allocator handles.
* @param alloc Pointer to the allocator (must have static lifetime).
*/
void register_device_allocator(
etensor::DeviceType type,
DeviceAllocator* alloc);

/**
* Get the device allocator for a specific device type.
*
* @param type The device type.
* @return Pointer to the allocator, or nullptr if not registered.
*/
DeviceAllocator* get_device_allocator(etensor::DeviceType type);

} // namespace runtime
} // namespace executorch

namespace torch {
namespace executor {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::runtime::DeviceAllocator;
using ::executorch::runtime::DeviceAllocatorRegistry;
using ::executorch::runtime::get_device_allocator;
using ::executorch::runtime::register_device_allocator;
} // namespace executor
} // namespace torch
Loading
Loading