forked from graphcore-research/gfloat
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathround_ndarray.py
More file actions
191 lines (145 loc) · 6.34 KB
/
round_ndarray.py
File metadata and controls
191 lines (145 loc) · 6.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.
from typing import Optional
from types import ModuleType
from .types import FormatInfo, RoundMode
import numpy as np
import array_api_compat
def _isodd(v: np.ndarray) -> np.ndarray:
return v & 0x1 == 1
def _ldexp(v: np.ndarray, s: np.ndarray) -> np.ndarray:
    """
    Scale ``v`` elementwise by two to the power ``s``.

    Torch, JAX and NumPy arrays are forwarded to the ``ldexp`` of their own
    array namespace. For any other array type the result is assembled from
    ordinary multiplications and a floating-point power of two.

    Args:
      v (float array): Values to scale
      s (int array): Power-of-two exponents, one per element of ``v``

    Returns:
      The scaled values, produced by the array namespace of ``v`` and ``s``.
    """
    xp = array_api_compat.array_namespace(v, s)

    native_backend_checks = (
        array_api_compat.is_torch_array,
        array_api_compat.is_jax_array,
        array_api_compat.is_numpy_array,
    )
    if any(is_backend(v) for is_backend in native_backend_checks):
        return xp.ldexp(v, s)

    # Scale away from subnormal/infinite ranges: the value is first moved by
    # 2**shift in one direction and the exponent compensated by the same
    # amount in the other, before the two factors are multiplied together.
    shift = 24
    scaled_up_first = (v * 2.0**shift) * 2.0 ** xp.astype(s - shift, v.dtype)
    scaled_down_first = (v * 2.0**-shift) * 2.0 ** xp.astype(s + shift, v.dtype)

    # Elements below 1.0 take the up-scaled form, all others the down-scaled one.
    return xp.where(v < 1.0, scaled_up_first, scaled_down_first)
def round_ndarray(
fi: FormatInfo,
v: np.ndarray,
rnd: RoundMode = RoundMode.TiesToEven,
sat: bool = False,
srbits: Optional[np.ndarray] = None,
srnumbits: int = 0,
) -> np.ndarray:
"""
Vectorized version of :meth:`round_float`.
Round inputs to the given :py:class:`FormatInfo`, given rounding mode and
saturation flag
Input NaNs will convert to NaNs in the target, not necessarily preserving payload.
An input Infinity will convert to the largest float if :paramref:`sat`,
otherwise to an Inf, if present, otherwise to a NaN.
Negative zero will be returned if the format has negative zero, otherwise zero.
Args:
fi (FormatInfo): Describes the target format
v (float array): Input values to be rounded
rnd (RoundMode): Rounding mode to use
sat (bool): Saturation flag: if True, round overflowed values to `fi.max`
srbits (int array): Bits to use for stochastic rounding if rnd == Stochastic.
srnumbits (int): How many bits are in srbits. Implies srbits < 2**srnumbits.
Returns:
An array of floats which is a subset of the format's value set.
Raises:
ValueError: The target format cannot represent an input
(e.g. converting a `NaN`, or an `Inf` when the target has no
`NaN` or `Inf`, and :paramref:`sat` is false)
"""
xp = array_api_compat.array_namespace(v, srbits)
p = fi.precision
bias = fi.expBias
is_negative = xp.signbit(v) & fi.is_signed
absv = xp.where(is_negative, -v, v)
finite_nonzero = ~(xp.isnan(v) | xp.isinf(v) | (v == 0))
# Place 1.0 where finite_nonzero is False, to avoid log of {0,inf,nan}
absv_masked = xp.where(finite_nonzero, absv, 1.0)
int_type = xp.int64 if fi.k > 8 or srnumbits > 8 else xp.int16
def to_int(x: np.ndarray) -> np.ndarray:
return xp.astype(x, int_type)
def to_float(x: np.ndarray) -> np.ndarray:
return xp.astype(x, v.dtype)
expval = to_int(xp.floor(xp.log2(absv_masked)))
if fi.has_subnormals:
expval = xp.maximum(expval, 1 - bias)
expval = expval - p + 1
fsignificand = _ldexp(absv_masked, -expval)
floorfsignificand = xp.floor(fsignificand)
isignificand = to_int(floorfsignificand)
delta = fsignificand - floorfsignificand
if fi.precision > 1:
code_is_odd = _isodd(isignificand)
else:
code_is_odd = (isignificand != 0) & _isodd(expval + bias)
match rnd:
case RoundMode.TowardZero:
should_round_away = xp.zeros_like(delta, dtype=xp.bool)
case RoundMode.TowardPositive:
should_round_away = ~is_negative & (delta > 0)
case RoundMode.TowardNegative:
should_round_away = is_negative & (delta > 0)
case RoundMode.TiesToAway:
should_round_away = delta >= 0.5
case RoundMode.TiesToEven:
should_round_away = (delta > 0.5) | ((delta == 0.5) & code_is_odd)
case RoundMode.Stochastic:
assert srbits is not None
## RTNE delta to srbits
d = delta * 2.0**srnumbits
floord = to_int(xp.floor(d))
dd = d - xp.floor(d)
should_round_away_tne = (dd > 0.5) | ((dd == 0.5) & _isodd(floord))
drnd = floord + xp.astype(should_round_away_tne, floord.dtype)
should_round_away = drnd + srbits >= 2**srnumbits
case RoundMode.StochasticOdd:
assert srbits is not None
## RTNO delta to srbits
d = delta * 2.0**srnumbits
floord = to_int(xp.floor(d))
dd = d - xp.floor(d)
should_round_away_tno = (dd > 0.5) | ((dd == 0.5) & ~_isodd(floord))
drnd = floord + xp.astype(should_round_away_tno, floord.dtype)
should_round_away = drnd + srbits >= 2**srnumbits
case RoundMode.StochasticFast:
assert srbits is not None
should_round_away = (
delta + to_float(2 * srbits + 1) * 2.0 ** -(1 + srnumbits) >= 1.0
)
case RoundMode.StochasticFastest:
assert srbits is not None
should_round_away = delta + to_float(srbits) * 2.0**-srnumbits >= 1.0
isignificand = xp.where(should_round_away, isignificand + 1, isignificand)
fresult = _ldexp(to_float(isignificand), expval)
result = xp.where(finite_nonzero, fresult, absv)
amax = xp.where(is_negative, -fi.min, fi.max)
if sat:
result = xp.where(result > amax, amax, result)
else:
match rnd:
case RoundMode.TowardNegative:
put_amax_at = (result > amax) & ~is_negative
case RoundMode.TowardPositive:
put_amax_at = (result > amax) & is_negative
case RoundMode.TowardZero:
put_amax_at = result > amax
case _:
put_amax_at = xp.zeros_like(result, dtype=xp.bool)
result = xp.where(finite_nonzero & put_amax_at, amax, result)
# Now anything larger than amax goes to infinity or NaN
if fi.has_infs:
result = xp.where(result > amax, xp.inf, result)
elif fi.num_nans > 0:
result = xp.where(result > amax, xp.nan, result)
else:
if xp.any(result > amax):
raise ValueError(f"No Infs or NaNs in format {fi}, and sat=False")
result = xp.where(is_negative, -result, result)
# Make negative zeros negative if has_nz, else make them not negative.
if fi.has_nz:
result = xp.where((result == 0) & is_negative, -0.0, result)
else:
result = xp.where(result == 0, 0.0, result)
return result