Skip to content

Commit 84ccb45

Browse files
author
Changming Sun
committed
update
1 parent 47a557a commit 84ccb45

5 files changed

Lines changed: 120 additions & 6 deletions

File tree

tests/unittest/matrix_buffer.h

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include <cstdlib> // For malloc, free, abort
7+
#include <cstddef> // For size_t
8+
#include <functional> // For std::function
9+
#include <algorithm> // For std::fill_n
10+
#include <new> // For std::bad_alloc (alternative to abort)
11+
12+
// Include crtdbg.h for _malloc_dbg and _free_dbg on Windows debug builds
13+
#if defined(_WIN32) && !defined(NDEBUG) && defined(_DEBUG)
14+
#include <crtdbg.h>
15+
#endif
16+
17+
// Owns a single reusable heap buffer of `T` for unit tests.
//
// The buffer only grows: a request for more elements than are currently
// allocated frees the old storage and allocates a new block; a request that
// fits reuses the existing block, so the returned pointer stays the same.
// Requesting zero elements releases the storage and yields nullptr.
//
// Storage comes from malloc (or _malloc_dbg on Windows debug builds), so no
// constructors or destructors of `T` are ever run on the elements. The class
// is therefore only suitable for trivially-constructible/destructible
// element types such as the arithmetic types.
//
// Allocation failure and size overflow terminate the process via std::abort.
template <typename T>
class MatrixGuardBuffer {
 public:
  MatrixGuardBuffer() : buffer_(nullptr), elements_allocated_(0) {}

  ~MatrixGuardBuffer() { ReleaseBuffer(); }

  // The object owns raw storage; copying or moving is intentionally disabled.
  MatrixGuardBuffer(const MatrixGuardBuffer&) = delete;
  MatrixGuardBuffer& operator=(const MatrixGuardBuffer&) = delete;
  MatrixGuardBuffer(MatrixGuardBuffer&&) = delete;
  MatrixGuardBuffer& operator=(MatrixGuardBuffer&&) = delete;

  // Returns a buffer holding at least `elements` items, optionally initialized.
  //
  // elements:  number of `T` items required. Zero releases the buffer and
  //            returns nullptr.
  // fill_func: if non-empty, called as fill_func(buffer, elements) after the
  //            storage is available. An empty std::function leaves the
  //            contents untouched (uninitialized for a fresh allocation,
  //            stale for a reused one).
  //
  // Aborts if `elements * sizeof(T)` does not fit in size_t or if the
  // allocation fails.
  T* GetFilledBuffer(size_t elements, const std::function<void(T*, size_t)>& fill_func) {
    if (elements == 0) {
      ReleaseBuffer();
      return nullptr;
    }

    if (elements > elements_allocated_) {
      ReleaseBuffer();

      // Validate the byte count *before* multiplying: unsigned overflow wraps
      // silently and would otherwise produce an undersized allocation.
      const size_t max_elements = static_cast<size_t>(-1) / sizeof(T);
      if (elements > max_elements) {
        std::abort();
      }

      buffer_ = static_cast<T*>(AllocateBytes(elements * sizeof(T)));
      if (buffer_ == nullptr) {
        std::abort();
      }
      elements_allocated_ = elements;
    }

    // Invariant here: elements > 0 and elements <= elements_allocated_, hence
    // buffer_ is non-null.
    if (fill_func) {
      fill_func(buffer_, elements);
    }

    return buffer_;
  }

  // Returns a buffer holding at least `elements` items.
  //
  // With zero_fill == true the first `elements` items are set to T{}.
  // Otherwise the contents are deliberately left untouched so that tools can
  // detect reads of uninitialized values.
  T* GetBuffer(size_t elements, bool zero_fill = false) {
    if (!zero_fill) {
      // An empty std::function means "do not touch the memory".
      return GetFilledBuffer(elements, nullptr);
    }

    return GetFilledBuffer(
        elements,
        [](T* start, size_t count) {
          std::fill_n(start, count, T{});  // Value-initialize
        });
  }

  // Frees the storage (if any) and resets the object to the empty state.
  void ReleaseBuffer() {
    if (buffer_ != nullptr) {
      FreeBytes(buffer_);
      buffer_ = nullptr;
    }
    elements_allocated_ = 0;
  }

 private:
  // Allocates `bytes` bytes; returns nullptr on failure. Must be paired with
  // FreeBytes so the debug and release CRT paths always match.
  static void* AllocateBytes(size_t bytes) {
#if defined(_WIN32) && !defined(NDEBUG) && defined(_DEBUG)
    return _malloc_dbg(bytes, _NORMAL_BLOCK, __FILE__, __LINE__);
#else
    return std::malloc(bytes);
#endif
  }

  // Releases memory obtained from AllocateBytes.
  static void FreeBytes(void* memory) {
#if defined(_WIN32) && !defined(NDEBUG) && defined(_DEBUG)
    _free_dbg(memory, _NORMAL_BLOCK);
#else
    std::free(memory);
#endif
  }

  T* buffer_;                  // Owned storage; nullptr when empty.
  size_t elements_allocated_;  // Capacity of buffer_ in elements of T.
};

tests/unittest/test_fgemm.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ class MlasFgemmTest : public MlasTestBase {
202202
for (size_t m = 0; m < M; m++) {
203203
for (size_t n = 0; n < N; n++, f++) {
204204
// Sensitive to comparing positive/negative zero.
205-
ASSERT_EQ(C[f], CReference[f])
205+
ASSERT_NEAR(C[f], CReference[f],1e-5)
206206
<< " Diff @[" << batch << ", " << m << ", " << n << "] f=" << f << ", "
207207
<< (Packed ? "Packed" : "NoPack") << "."
208208
<< (Threaded ? "SingleThread" : "Threaded") << "/"

tests/unittest/test_main.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ bool AddTestRegister(TestRegister test_register) {
5757
}
5858

5959
int main(int argc, char** argv) {
60+
unsigned int current_control;
61+
_controlfp_s(&current_control, 0, 0); // Get current control word
62+
_controlfp_s(&current_control, ~(_EM_INVALID | _EM_ZERODIVIDE | _EM_DENORMAL), _MCW_EM); // Unmask exceptions
6063
bool is_short_execute = (argc <= 1 || strcmp("--long", argv[1]) != 0);
6164
std::cout << "-------------------------------------------------------" << std::endl;
6265
if (is_short_execute) {

tests/unittest/test_scaleoutput.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class MlasScaleOutputTest : public MlasTestBase {
2222
std::numeric_limits<int16_t>::max());
2323

2424
for (size_t s = 0; s < M * N; s++) {
25-
Input[s] = int_distribution(generator);
25+
Input[s] = int_distribution(generator); //It could be zero
2626
Output[s] = OutputRef[s] = real_distribution(generator);
2727
}
2828

@@ -52,10 +52,14 @@ class MlasScaleOutputTest : public MlasTestBase {
5252
constexpr float epsilon = 1e-6f;
5353

5454
for (size_t n = 0; n < M * N; n++) {
55-
float diff = std::fabs((Output[n] - OutputRef[n]) / OutputRef[n]);
55+
float outvalue = OutputRef[n]; // When `AccumulateMode` is false, there is a high chance that this value could be zero
56+
float diff = std::fabs(Output[n] - outvalue) ;
57+
if (outvalue != 0) {
58+
diff /= outvalue;
59+
}
5660
ASSERT_LE(diff, epsilon)
5761
<< " @[" << n / N << "," << n % N << "], total:[" << M << "," << N << "], got:"
58-
<< Output[n] << ", expecting:" << OutputRef[n];
62+
<< Output[n] << ", expecting:" << outvalue;
5963
}
6064
}
6165

tests/unittest/test_util.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@
3737
#endif
3838

3939
MLAS_THREADPOOL* GetMlasThreadPool(void);
40-
40+
#ifdef BUILD_MLAS_NO_ONNXRUNTIME
41+
#include "matrix_buffer.h"
42+
#else
4143
template <typename T>
4244
class MatrixGuardBuffer {
4345
public:
@@ -163,7 +165,7 @@ class MatrixGuardBuffer {
163165
size_t _BaseBufferSize;
164166
T* _GuardAddress;
165167
};
166-
168+
#endif
167169
class MlasTestBase {
168170
public:
169171
virtual ~MlasTestBase(void) {}

0 commit comments

Comments
 (0)