forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcpu_execution_provider.h
More file actions
61 lines (49 loc) · 2.01 KB
/
cpu_execution_provider.h
File metadata and controls
61 lines (49 loc) · 2.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <functional>
#include <limits>
#include <memory>
#include <vector>

#include "core/framework/allocatormgr.h"
#include "core/framework/execution_provider.h"
#include "core/graph/constants.h"
namespace onnxruntime {
/// Options consumed when constructing a CPUExecutionProvider.
struct CPUExecutionProviderInfo {
  /// Requests an arena allocator (defaults to true). Whether the request is
  /// honored depends on how the provider was built.
  bool create_arena = true;

  /// Keeps the default above: arena requested.
  CPUExecutionProviderInfo() = default;

  /// @param use_arena Value stored into `create_arena`.
  explicit CPUExecutionProviderInfo(bool use_arena) : create_arena{use_arena} {}
};
using FuseRuleFn = std::function<void(const onnxruntime::GraphViewer&,
std::vector<std::unique_ptr<ComputeCapability>>&)>;
// Logical device representation.
class CPUExecutionProvider : public IExecutionProvider {
 public:
  /// Creates the CPU execution provider and registers one allocator for
  /// OrtMemTypeDefault.
  ///
  /// @param info Construction options. `info.create_arena` selects an arena
  ///             allocator, but it is only honored in non-jemalloc builds on
  ///             64-bit x86 or ARM; everywhere else the plain device allocator
  ///             is used regardless of the flag.
  explicit CPUExecutionProvider(const CPUExecutionProviderInfo& info)
      : IExecutionProvider{onnxruntime::kCpuExecutionProvider} {
    // Device allocator description: default memory type, a factory producing a
    // TAllocator, and size_t max as the last field (presumably "no memory cap" —
    // confirm against DeviceAllocatorRegistrationInfo).
    DeviceAllocatorRegistrationInfo device_info{OrtMemTypeDefault,
                                                [](int) { return onnxruntime::make_unique<TAllocator>(); },
                                                std::numeric_limits<size_t>::max()};
#ifdef USE_JEMALLOC
#if defined(USE_MIMALLOC_ARENA_ALLOCATOR) || defined(USE_MIMALLOC_STL_ALLOCATOR)
#error jemalloc and mimalloc should not both be enabled
#endif
    ORT_UNUSED_PARAMETER(info);
    // JEMalloc already has memory pool, so just use device allocator.
    InsertAllocator(device_info.factory(0));
#else
    // The arena allocator is disabled on 32-bit builds (e.g. x86_32) because it may
    // run into an infinite loop when integer overflow happens. Only architectures
    // known to be 64-bit may opt in; ARM64 is listed explicitly so it is not caught
    // by the 32-bit restriction.
#if defined(__amd64__) || defined(_M_AMD64) || defined(__aarch64__) || defined(_M_ARM64)
    if (info.create_arena) {
      InsertAllocator(CreateAllocator(device_info));
    } else {
      InsertAllocator(device_info.factory(0));
    }
#else
    ORT_UNUSED_PARAMETER(info);
    InsertAllocator(device_info.factory(0));
#endif
#endif
  }

  // IExecutionProvider overrides; defined out of line (not in this header).
  std::shared_ptr<KernelRegistry> GetKernelRegistry() const override;
  std::unique_ptr<IDataTransfer> GetDataTransfer() const override;

 private:
  // NOTE(review): not referenced in this header; presumably used by the .cc
  // implementation — confirm before removing.
  std::vector<FuseRuleFn> fuse_rules_;
};
} // namespace onnxruntime