-
Notifications
You must be signed in to change notification settings - Fork 250
Expand file tree
/
Copy pathtensor.rs
More file actions
93 lines (82 loc) · 2.68 KB
/
tensor.rs
File metadata and controls
93 lines (82 loc) · 2.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
use std::{slice, sync::Arc, vec};
/// A view over a reference-counted flat buffer of `T`, interpreted with `shape`.
///
/// Several tensors may share one buffer: [`Tensor::slice`] clones the `Arc` and only changes
/// `offset`, `length` and `shape`, so no elements are copied. Every constructor in this impl
/// maintains `offset + length <= data.len()` and `shape.iter().product() == length`.
pub struct Tensor<T> {
    /// Backing storage, shared by this tensor and every slice taken from it.
    data: Arc<Box<[T]>>,
    /// Logical dimensions; their product equals `length`.
    shape: Vec<usize>,
    /// Index of this view's first element inside `data`.
    offset: usize,
    /// Number of elements visible through this view.
    length: usize,
}

impl<T: Copy + Clone + Default> Tensor<T> {
    /// Creates a tensor that owns `data` and views all of it with the given `shape`.
    ///
    /// # Panics
    /// Panics if `data.len()` differs from the product of `shape` (an empty shape has
    /// product 1, i.e. a scalar).
    pub fn new(data: Vec<T>, shape: &Vec<usize>) -> Self {
        let length = data.len();
        let expected: usize = shape.iter().product();
        assert_eq!(
            length, expected,
            "data of length {length} does not match shape {shape:?}"
        );
        Tensor {
            data: Arc::new(data.into_boxed_slice()),
            shape: shape.clone(),
            offset: 0,
            length,
        }
    }

    /// Creates a tensor of the given `shape` with every element set to `T::default()`.
    pub fn default(shape: &Vec<usize>) -> Self {
        let length = shape.iter().product();
        Self::new(vec![T::default(); length], shape)
    }

    /// Returns the elements visible through this view.
    pub fn data(&self) -> &[T] {
        &self.data[self.offset..][..self.length]
    }

    /// Returns the elements visible through this view as a mutable slice.
    ///
    /// # Safety
    /// The backing buffer is shared (through `Arc`) with every tensor produced by
    /// [`Tensor::slice`] from the same source, and this method hands out `&mut` access without
    /// checking that. For the lifetime of the returned slice the caller must ensure that no other
    /// tensor sharing the buffer reads or writes an overlapping range.
    ///
    /// NOTE(review): the pointer is derived from a shared borrow of the buffer, which Rust's
    /// aliasing rules do not sanction writing through; a fully sound design would store the
    /// elements in `UnsafeCell`s.
    pub unsafe fn data_mut(&mut self) -> &mut [T] {
        // SAFETY: `offset + length <= data.len()` is upheld by `new` and `slice`, so the range
        // is in bounds; exclusive access to that range is the caller's obligation (`# Safety`).
        unsafe {
            let ptr = self.data.as_ptr().add(self.offset) as *mut T;
            slice::from_raw_parts_mut(ptr, self.length)
        }
    }

    /// Returns the logical dimensions of this view.
    pub fn shape(&self) -> &Vec<usize> {
        &self.shape
    }

    /// Returns the number of elements visible through this view.
    pub fn size(&self) -> usize {
        self.length
    }

    /// Reinterprets the tensor with `new_shape` while keeping the same elements.
    ///
    /// # Panics
    /// Panics if the product of `new_shape` differs from the current element count.
    pub fn reshape(&mut self, new_shape: &Vec<usize>) -> &mut Self {
        let new_length: usize = new_shape.iter().product();
        if new_length != self.length {
            let old_shape = &self.shape;
            panic!("New shape {new_shape:?} does not match tensor of {old_shape:?}");
        }
        // Reuses the existing allocation when it is large enough.
        self.shape.clone_from(new_shape);
        self
    }

    /// Returns a view of `shape.iter().product()` consecutive elements starting `start`
    /// elements into this view. The result shares this tensor's buffer; nothing is copied.
    ///
    /// # Panics
    /// Panics if the requested range does not fit inside this view.
    pub fn slice(&self, start: usize, shape: &Vec<usize>) -> Self {
        let new_length: usize = shape.iter().product();
        // Checked as a subtraction so that `start + new_length` cannot overflow.
        assert!(
            new_length <= self.length && start <= self.length - new_length,
            "slice of {new_length} elements at {start} is out of bounds for length {}",
            self.length
        );
        Tensor {
            data: Arc::clone(&self.data),
            shape: shape.clone(),
            offset: self.offset + start,
            length: new_length,
        }
    }
}
// Some helper functions for testing and debugging
impl Tensor<f32> {
    /// Returns `true` if `self` and `other` have the same shape and every pair of
    /// corresponding elements is equal within relative tolerance `rel`, as judged by
    /// [`float_eq`].
    #[allow(unused)] // test/debug helper; may have no callers outside tests
    pub fn close_to(&self, other: &Self, rel: f32) -> bool {
        self.shape() == other.shape()
            && self
                .data()
                .iter()
                .zip(other.data())
                .all(|(x, y)| float_eq(x, y, rel))
    }

    /// Prints the view's shape, offset and length, then its elements one row per line,
    /// where a row holds as many elements as the last dimension.
    ///
    /// A scalar (empty shape) is printed as a single row; if the last dimension is 0 only the
    /// header line is printed.
    #[allow(unused)] // debugging aid; may have no callers
    pub fn print(&self) {
        println!(
            "shape: {:?}, offset: {}, length: {}",
            self.shape, self.offset, self.length
        );
        let data = self.data();
        // An empty shape has no last dimension: print everything as one row instead of
        // underflowing `len() - 1`.
        let row_len = self.shape.last().copied().unwrap_or(data.len());
        if row_len == 0 {
            // Nothing to print, and `chunks(0)` would panic.
            return;
        }
        for row in data.chunks(row_len) {
            println!("{row:?}");
        }
    }
}
/// Returns `true` if `x` and `y` are equal within relative tolerance `rel`, i.e.
/// `|x - y| <= rel * (|x| + |y|) / 2`.
///
/// Identical values (including equal infinities, and `0.0` vs `-0.0`) always compare equal.
/// Any other pair involving an infinity or NaN compares unequal; otherwise the tolerance
/// itself would become infinite and accept everything.
#[inline]
pub fn float_eq(x: &f32, y: &f32, rel: f32) -> bool {
    if x == y {
        return true;
    }
    if !(x.is_finite() && y.is_finite()) {
        return false;
    }
    // Halve each magnitude before adding so the tolerance cannot overflow to infinity for
    // values near `f32::MAX` (halving is exact, so ordinary results are unchanged).
    (x - y).abs() <= rel * (x.abs() / 2.0 + y.abs() / 2.0)
}