1+ // Copyright (c) Microsoft Corporation. All rights reserved.
2+ // Licensed under the MIT License.
3+
4+ #pragma once
5+
6+ #include < cstdlib> // For malloc, free, abort
7+ #include < cstddef> // For size_t
8+ #include < functional> // For std::function
9+ #include < algorithm> // For std::fill_n
10+ #include < new> // For std::bad_alloc (alternative to abort)
11+
12+ // Include crtdbg.h for _malloc_dbg and _free_dbg on Windows debug builds
13+ #if defined(_WIN32) && !defined(NDEBUG) && defined(_DEBUG)
14+ #include < crtdbg.h>
15+ #endif
16+
17+ template <typename T>
18+ class MatrixGuardBuffer {
19+ public:
20+ MatrixGuardBuffer () :
21+ buffer_ (nullptr ),
22+ elements_allocated_ (0 ) {
23+ }
24+
25+ ~MatrixGuardBuffer () {
26+ ReleaseBuffer ();
27+ }
28+
29+ // Disable copy and move semantics for simplicity
30+ MatrixGuardBuffer (const MatrixGuardBuffer&) = delete ;
31+ MatrixGuardBuffer& operator =(const MatrixGuardBuffer&) = delete ;
32+ MatrixGuardBuffer (MatrixGuardBuffer&&) = delete ;
33+ MatrixGuardBuffer& operator =(MatrixGuardBuffer&&) = delete ;
34+
35+ T* GetFilledBuffer (size_t elements, const std::function<void (T*, size_t )>& fill_func) {
36+ if (elements == 0 ) {
37+ ReleaseBuffer ();
38+ return nullptr ;
39+ }
40+
41+ if (elements > elements_allocated_) {
42+ ReleaseBuffer ();
43+
44+ size_t bytes_to_allocate = elements * sizeof (T);
45+ if (elements != 0 && bytes_to_allocate / elements != sizeof (T)) { // Check for overflow before multiplication
46+ // Handle overflow, e.g., by aborting or throwing
47+ abort ();
48+ }
49+
50+ #if defined(_WIN32) && !defined(NDEBUG) && defined(_DEBUG)
51+ buffer_ = static_cast <T*>(_malloc_dbg (bytes_to_allocate, _NORMAL_BLOCK, __FILE__, __LINE__));
52+ #else
53+ buffer_ = static_cast <T*>(malloc (bytes_to_allocate));
54+ #endif
55+
56+ if (buffer_ == nullptr ) {
57+ // Consider `throw std::bad_alloc();` for C++ style error handling.
58+ abort ();
59+ }
60+ elements_allocated_ = elements;
61+ }
62+
63+ if (fill_func && buffer_ != nullptr ) {
64+ fill_func (buffer_, elements);
65+ } else if (buffer_ == nullptr && elements > 0 ) {
66+ abort (); // Should not happen if allocation failure aborts
67+ }
68+
69+ return buffer_;
70+ }
71+
72+ T* GetBuffer (size_t elements, bool zero_fill = false ) {
73+ if (zero_fill) {
74+ return GetFilledBuffer (
75+ elements,
76+ [](T* start, size_t count) {
77+ if (start && count > 0 ) {
78+ std::fill_n (start, count, T{}); // Value-initialize
79+ }
80+ });
81+ }
82+
83+ return GetFilledBuffer (
84+ elements,
85+ [](T* start, size_t count) {
86+ // do nothing, so that we can catch read uninitialized values errors
87+ });
88+ }
89+
90+ void ReleaseBuffer () {
91+ if (buffer_ != nullptr ) {
92+ #if defined(_WIN32) && !defined(NDEBUG) && defined(_DEBUG)
93+ _free_dbg (buffer_, _NORMAL_BLOCK);
94+ #else
95+ free (buffer_);
96+ #endif
97+ buffer_ = nullptr ;
98+ }
99+ elements_allocated_ = 0 ;
100+ }
101+
102+ private:
103+ T* buffer_;
104+ size_t elements_allocated_;
105+ };
0 commit comments