-
Notifications
You must be signed in to change notification settings - Fork 972
Expand file tree
/
Copy pathQnnManager.h
More file actions
118 lines (100 loc) · 3.79 KB
/
QnnManager.h
File metadata and controls
118 lines (100 loc) · 3.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
/*
* Copyright (c) Qualcomm Innovation Center, Inc.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <executorch/backends/qualcomm/aot/wrappers/OpWrapper.h>
#include <executorch/backends/qualcomm/aot/wrappers/TensorWrapper.h>
#include <executorch/backends/qualcomm/runtime/Logging.h>
#include <executorch/backends/qualcomm/runtime/QnnExecuTorch.h>
#include <executorch/backends/qualcomm/runtime/backends/QnnBackendFactory.h>
#include <executorch/backends/qualcomm/schema_generated.h>
#include <executorch/runtime/core/error.h>
#include <memory>
#include <unordered_map>
namespace torch {
namespace executor {
namespace qnn {
// QnnManager owns one QNN backend session end-to-end: loading the
// platform-specific QNN backend library (names listed below), building the
// backend stack via QnnBackendFactory, allocating graph I/O tensors,
// compiling op graphs into a context binary, executing them, and
// registering shared-memory buffers with the backend.
//
// NOTE(review): all non-inline members are defined in the corresponding
// .cpp, which is not visible here — comments on those members describe the
// declared contract only.
class QnnManager {
 public:
  // Construct QnnManager
  //
  // `options` is held as a raw pointer (see options_ below) — the caller
  // must keep it alive for the lifetime of this manager.
  // `qnn_executorch_context_binary` is a pre-built context blob,
  // presumably deserialized during Init() — confirm against the .cpp.
  explicit QnnManager(
      const QnnExecuTorchOptions* options,
      const QnnExecuTorchContextBinary& qnn_executorch_context_binary);

  ~QnnManager();

  // One-time setup of the backend session; call before Execute().
  // Returns Error::Ok on success.
  Error Init();

  // Allocate the graph's input/output tensors into input_tensors_ /
  // output_tensors_ (retrievable via GetGraphInputs()/GetGraphOutputs()).
  Error AllocateTensor();

  // Overload taking caller-provided tensor wrappers instead of
  // allocating fresh ones.
  Error AllocateTensor(
      std::vector<std::shared_ptr<TensorWrapper>>& inputs,
      std::vector<std::shared_ptr<TensorWrapper>>& outputs);

  // Run the compiled graph. `output_tensor_structs` is filled in place;
  // `event_tracer` receives profiling events (may interact with
  // ProfileExecuteData — verify in the .cpp).
  Error Execute(
      const std::vector<Qnn_Tensor_t>& input_tensor_structs,
      std::vector<Qnn_Tensor_t>& output_tensor_structs,
      EventTracer* event_tracer);

  // Report backend profiling data from the last execution into the
  // ExecuTorch event tracer.
  Error ProfileExecuteData(EventTracer* event_tracer);

  // Tear down backend resources. Safe to call before the destructor runs.
  void Destroy();

  // Always true: availability is decided at link/load time, not here.
  bool IsAvailable() {
    return true;
  }

  // True when the user options request online (on-device) graph prepare.
  bool IsOnlinePrepare() {
    return options_->online_prepare();
  }

  // True when the user options request dumping of intermediate outputs.
  bool IsTensorDump() {
    return options_->dump_intermediate_outputs();
  }

  // Query the backend whether every op in `op_wrappers` is supported.
  bool IsNodeSupportedByBackend(
      std::vector<std::shared_ptr<OpWrapper>>& op_wrappers);

  // Compile `op_wrappers` into a graph and emit the serialized context
  // binary through `qnn_executorch_context_binary`.
  Error Compile(
      std::vector<std::shared_ptr<OpWrapper>>& op_wrappers,
      QnnExecuTorchContextBinary& qnn_executorch_context_binary);

  // Register `data_ptr` as backing memory for `tensor_wrapper`.
  // Dispatches to RegisterIonMem / RegisterCustomMem internally —
  // confirm the dispatch criteria in the .cpp.
  Error RegisterMem(
      void* data_ptr,
      const std::shared_ptr<TensorWrapper>& tensor_wrapper);

  // Pre-register custom memory handle from the SharedBuffer before execution
  Error PreRegisterMem();

  // Returns a COPY of the graph input tensor list (vector of shared_ptr;
  // each copy bumps refcounts).
  std::vector<std::shared_ptr<TensorWrapper>> GetGraphInputs() {
    return input_tensors_;
  }

  // Returns a COPY of the graph output tensor list.
  std::vector<std::shared_ptr<TensorWrapper>> GetGraphOutputs() {
    return output_tensors_;
  }

 private:
  // dlopen/LoadLibrary the backend library selected from the names below.
  Error LoadQnnLibrary();

  // Backend library file names per platform (HTP/GPU/DSP).
#ifdef _WIN32
  static constexpr const char* htp_library_name_ = "QnnHtp.dll";
  static constexpr const char* gpu_library_name_ = "QnnGpu.dll";
  static constexpr const char* dsp_library_name_ = "QnnDsp.dll";
#else
  static constexpr const char* htp_library_name_ = "libQnnHtp.so";
  static constexpr const char* gpu_library_name_ = "libQnnGpu.so";
  static constexpr const char* dsp_library_name_ = "libQnnDsp.so";
#endif

  // Serialized context binary passed at construction.
  QnnExecuTorchContextBinary qnn_context_blob_;
  // Backend/device/context/graph objects built by QnnBackendFactory.
  std::unique_ptr<BackendConfigParameters> backend_params_ptr_;
  QnnImplementation qnn_loaded_backend_;
  std::unique_ptr<QnnLogger> logger_;
  // Non-owning; lifetime managed by the caller (see constructor note).
  const QnnExecuTorchOptions* options_;
  std::vector<std::shared_ptr<TensorWrapper>> input_tensors_;
  std::vector<std::shared_ptr<TensorWrapper>> output_tensors_;

  // Register ION (Android shared-memory) backed buffers.
  Error RegisterIonMem(
      void* data_ptr,
      const std::shared_ptr<TensorWrapper>& tensor_wrapper);

  // Register a custom memory region; `custom_mem_base` is the base of the
  // pre-registered allocation containing `data_ptr`.
  Error RegisterCustomMem(
      void* data_ptr,
      void* custom_mem_base,
      const std::shared_ptr<TensorWrapper>& tensor_wrapper);

  // QNN dtype -> ExecuTorch ScalarType mapping. Note the fixed-point QNN
  // types map onto plain integer ScalarTypes of matching width/signedness.
  std::unordered_map<Qnn_DataType_t, ScalarType> qnn_dtype_to_scalar_type_ = {
      {Qnn_DataType_t::QNN_DATATYPE_INT_32, ScalarType::Int},
      {Qnn_DataType_t::QNN_DATATYPE_FLOAT_32, ScalarType::Float},
      {Qnn_DataType_t::QNN_DATATYPE_SFIXED_POINT_8, ScalarType::Char},
      {Qnn_DataType_t::QNN_DATATYPE_SFIXED_POINT_16, ScalarType::Short},
      {Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_8, ScalarType::Byte},
      {Qnn_DataType_t::QNN_DATATYPE_UFIXED_POINT_16, ScalarType::Bits16},
  };
};
} // namespace qnn
} // namespace executor
} // namespace torch