This repository was archived by the owner on Jan 26, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 222
Expand file tree
/
Copy pathcommon.h
More file actions
72 lines (61 loc) · 1.91 KB
/
common.h
File metadata and controls
72 lines (61 loc) · 1.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#ifndef LOGREG_UTIL_HELPER_H_
#define LOGREG_UTIL_HELPER_H_
#include "data_type.h"
namespace logreg {
/// Allocates a num_row x num_col matrix as an array of separately allocated
/// rows (matrix[r][c]). Row elements are default-initialized, which leaves
/// them indeterminate for arithmetic EleType.
///
/// Release the result with FreeMatrix(num_row, matrix).
///
/// Exception safety: if any allocation or element constructor throws, every
/// row allocated so far and the row-pointer array are freed before the
/// exception propagates. A failed call therefore leaks nothing.
template <typename EleType>
EleType** CreateMatrix(int num_row, int num_col) {
  EleType **matrix = new EleType*[num_row];
  int allocated = 0;  // number of rows successfully allocated so far
  try {
    for (; allocated < num_row; ++allocated)
      matrix[allocated] = new EleType[num_col];
  } catch (...) {
    // The throwing row was never stored (new[] already cleaned it up),
    // so only rows [0, allocated) need releasing here.
    for (int i = 0; i < allocated; ++i)
      delete[] matrix[i];
    delete[] matrix;
    throw;
  }
  return matrix;
}
/// Releases a matrix produced by CreateMatrix.
///
/// @param num_row  number of rows; must equal the num_row given to
///                 CreateMatrix for this matrix.
/// @param matrix   row-pointer array to free. nullptr is accepted and ignored,
///                 so callers may free unconditionally.
template <typename EleType>
void FreeMatrix(int num_row, EleType**matrix) {
  // Without this guard, a null matrix with num_row > 0 would be dereferenced.
  if (matrix == nullptr) return;
  for (int i = 0; i < num_row; ++i)
    delete[] matrix[i];
  delete[] matrix;
}
/// Computes the dot product of `sample` with the slice of `matrix` that
/// starts at `offset`.
///
/// Sparse matrix (matrix->sparse()): for each i, the weight is looked up via
/// matrix->Get(sample->keys[i] + offset). Keys for which Get returns nullptr
/// contribute 0.
/// Dense matrix: element i of matrix->raw() (shifted by `offset`) is
/// multiplied by sample->values[i], for every i < sample->values.size().
///
/// NOTE(review): the dense path assumes the matrix holds at least
/// offset + values.size() elements. This is not checked here.
template <typename EleType>
EleType Dot(size_t offset, DataBlock<EleType>*matrix, Sample<EleType>*sample) {
  EleType sum = 0;
  // Use size_t: narrowing values.size() to int made the bound negative above
  // INT_MAX entries, which silently skipped the whole product.
  const size_t size = sample->values.size();
  if (matrix->sparse()) {
    DEBUG_CHECK(sample->keys.size() == sample->values.size());
    for (size_t i = 0; i < size; ++i) {
      const EleType* pval = matrix->Get(sample->keys[i] + offset);
      if (pval != nullptr) {
        sum += sample->values[i] * (*pval);
      }
    }
  } else {
    const EleType* rawa = static_cast<const EleType*>(matrix->raw()) + offset;
    const EleType* rawb = sample->values.data();
    for (size_t i = 0; i < size; ++i) {
      sum += rawa[i] * rawb[i];
    }
  }
  return sum;
}
/// Returns a pointer to the first element of row `row_id` in a matrix stored
/// contiguously in row-major order with `num_col` elements per row.
/// No bounds checking is performed.
template <typename EleType>
inline EleType* MatrixRow(EleType*matrix, int row_id, size_t num_col) {
  const size_t row_start = static_cast<size_t>(row_id) * num_col;
  return &matrix[row_start];
}
/// Allocates an array of `num` heap-allocated samples. Each sample is
/// constructed as Sample<EleType>(sparse, size).
///
/// Release the result with FreeSamples(num, samples).
///
/// Exception safety: if an allocation or a Sample constructor throws, the
/// samples already created and the pointer array are freed before the
/// exception propagates.
template <typename EleType>
Sample<EleType>** CreateSamples(int num, size_t size, bool sparse) {
  Sample<EleType>**samples = new Sample<EleType>*[num];
  int created = 0;  // number of samples successfully constructed so far
  try {
    for (; created < num; ++created) {
      samples[created] = new Sample<EleType>(sparse, size);
    }
  } catch (...) {
    // The failing slot was never assigned; free only [0, created).
    for (int i = 0; i < created; ++i) {
      delete samples[i];
    }
    delete[] samples;
    throw;
  }
  return samples;
}
/// Releases an array of samples produced by CreateSamples.
///
/// @param num      number of samples; must equal the `num` given to
///                 CreateSamples for this array.
/// @param samples  pointer array to free. nullptr is accepted and ignored,
///                 so callers may free unconditionally.
template <typename EleType>
void FreeSamples(int num, Sample<EleType>**samples) {
  // Without this guard, a null array with num > 0 would be dereferenced.
  if (samples == nullptr) return;
  for (int i = 0; i < num; ++i) {
    delete samples[i];
  }
  delete[] samples;
}
// Expands to explicit instantiations of class template `name` for int, float
// and double.
// NOTE(review): presumably used in the .cpp that defines the template's
// members, so those definitions need not live in the header. Confirm at the
// call sites.
// The preprocessor ignores the surrounding namespace, so `name` must be
// resolvable (or qualified) where the macro is expanded. The expansion
// already ends in ';'.
#define DECLARE_TEMPLATE_CLASS_WITH_BASIC_TYPE(name) \
template class name<int>; \
template class name<float>; \
template class name<double>;
} // namespace logreg
#endif // LOGREG_UTIL_HELPER_H_