-
Notifications
You must be signed in to change notification settings - Fork 719
Expand file tree
/
Copy pathquantized_tensor.py
More file actions
615 lines (504 loc) · 22 KB
/
quantized_tensor.py
File metadata and controls
615 lines (504 loc) · 22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Pure Python base classes for quantization."""
from __future__ import annotations
from typing import Optional, Tuple, Iterable, Any, Dict, Union
import abc
import warnings
import math
import torch
from torch.utils._pytree import tree_map
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.tensor._quantization_helpers import (
_QuantizeFunc,
_IdentityFunc,
_stride_from_shape,
)
class QuantizedTensorStorage:
    r"""Base class for all TensorStorage classes.

    This class (and its subclasses) is an optimization for when the
    full QuantizedTensor is not needed, i.e. when the tensor is fully
    contained inside a torch.autograd function and is never visible to
    PyTorch's autograd.

    When creating a new tensor type X one should create both an
    XTensorStorage class inheriting from QuantizedTensorStorage and an
    XTensor class inheriting from XTensorStorage and QuantizedTensor.
    XTensorStorage should contain all data members needed to implement
    the functionality of the tensor, while XTensor should only implement
    the functionality needed to behave like a regular torch.Tensor
    (like __torch_dispatch__).

    """

    # Quantizer used for in-place (re)quantization. When it is None,
    # ``_get_quantizer`` falls back to ``_build_default_quantizer``.
    _quantizer: Optional[Quantizer]

    def update_usage(
        self,
        rowwise_usage: Optional[bool] = None,
        columnwise_usage: Optional[bool] = None,
    ):
        r"""
        Generate or remove quantized data based on provided usage.

        Must be implemented by subclasses; the base class raises
        ``NotImplementedError``.

        Parameters
        ----------
        rowwise_usage : Optional[bool], default = None
            Whether to create or keep the data needed for using the tensor
            in rowwise fashion (e.g. as B argument in TN GEMM). Leaving it
            as `None` preserves the original value in the tensor.
        columnwise_usage : Optional[bool], default = None
            Whether to create or keep the data needed for using the tensor
            in columnwise fashion (e.g. as A argument in TN GEMM). Leaving
            it as `None` preserves the original value in the tensor.

        """
        cls_name = self.__class__.__name__
        raise NotImplementedError(f"{cls_name} class does not implement update_usage function")

    def get_usages(self) -> Dict[str, bool]:
        """Get the usage of the tensor (subclass responsibility)."""
        cls_name = self.__class__.__name__
        raise NotImplementedError(f"{cls_name} class does not implement get_usages function")

    def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]:
        """Prepare the tensor storage for saving for backward (subclass responsibility)."""
        cls_name = self.__class__.__name__
        raise NotImplementedError(
            f"{cls_name} class does not implement prepare_for_saving function"
        )

    def restore_from_saved(
        self, tensors: list[Optional[torch.Tensor]]
    ) -> list[Optional[torch.Tensor]]:
        """Restore the storage data from the saved tensors list (subclass responsibility)."""
        cls_name = self.__class__.__name__
        raise NotImplementedError(
            f"{cls_name} class does not implement restore_from_saved function"
        )

    def _get_quantizer(self) -> Quantizer:
        """Return the quantizer to use for in-place operations.

        The attached ``_quantizer`` wins; otherwise a default one is built
        by ``_build_default_quantizer``.
        """
        quantizer = self._quantizer
        return quantizer if quantizer is not None else self._build_default_quantizer()

    def _build_default_quantizer(self) -> Quantizer:
        """Build a default quantizer; the base class has none and raises ``ValueError``."""
        message = (
            f"{self.__class__.__name__} has no quantizer "
            "and no default quantizer is available defined in the subclass."
        )
        raise ValueError(message)

    def quantize_(
        self, tensor: torch.Tensor, *, noop_flag: Optional[torch.Tensor] = None
    ) -> QuantizedTensor:
        """Quantize ``tensor`` into this storage in-place and return ``self``."""
        quantizer = self._get_quantizer()
        quantizer.update_quantized(tensor, self, noop_flag=noop_flag)
        return self

    def update_quantizer(self, quantizer: Quantizer):
        """Replace the attached quantizer.

        Raises ``RuntimeError`` if no quantizer is attached yet, and emits a
        warning when the new quantizer is a different object.
        """
        current = self._quantizer
        if current is None:
            raise RuntimeError("To be updated, quantizer must be set")
        if quantizer is not current:
            warnings.warn("Quantizer is being updated, this may affect model behavior")
        self._quantizer = quantizer
def prepare_for_saving(
    *tensors: Union[torch.Tensor, QuantizedTensorStorage],
) -> Tuple[
    list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], list[Optional[QuantizedTensorStorage]]
]:
    """Flatten tensors and TensorStorage objects for ``save_for_backward``.

    ``save_for_backward`` accepts only torch.Tensor/torch.nn.Parameter
    values, but we also want to save the internal TensorStorage types.
    Plain tensors and ``None`` pass through unchanged. Any other object has
    its own ``prepare_for_saving`` called. The tensors it returns are
    appended to the flat list, and the object it returns is kept so
    ``restore_from_saved`` can rebuild it.

    Returns
    -------
    A pair ``(flat_tensors, storage_objects)``. ``storage_objects`` has one
    entry per input, holding ``None`` for inputs that were plain tensors or
    ``None``.
    """
    flat_tensors = []
    storage_objects = []
    for item in tensors:
        if item is not None and not isinstance(item, torch.Tensor):
            saved, storage = item.prepare_for_saving()
            flat_tensors.extend(saved)
            storage_objects.append(storage)
            continue
        flat_tensors.append(item)
        storage_objects.append(None)
    return flat_tensors, storage_objects
def restore_from_saved(
    tensors: list[Optional[Union[torch.Tensor, QuantizedTensorStorage]]],
    saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]],
    return_saved_tensors: bool = False,
) -> (
    list[Optional[torch.Tensor | QuantizedTensorStorage]]
    | tuple[list[Optional[torch.Tensor | QuantizedTensorStorage]], list[Optional[torch.Tensor]]]
):
    """Recombine the tensor data and metadata during backward pass.

    Walks ``tensors`` (the per-input objects produced by
    ``prepare_for_saving``) in order, consuming ``saved_tensors`` from the
    front:

    * ``None`` or a plain tensor slot takes the next saved tensor.
    * a TensorStorage object restores itself via its own
      ``restore_from_saved``, which returns the still-unconsumed tail.

    When ``return_saved_tensors`` is True the unconsumed tail is returned
    alongside the restored objects.
    """
    restored = []
    remaining = saved_tensors
    for entry in tensors:
        if entry is not None and not isinstance(entry, torch.Tensor):
            remaining = entry.restore_from_saved(remaining)
            restored.append(entry)
            continue
        restored.append(remaining[0])
        remaining = remaining[1:]
    if not return_saved_tensors:
        return restored
    return restored, remaining
class Quantizer(abc.ABC):
    """Builder class for quantized tensors.

    This class is typically used to convert a high-precision tensor
    (e.g. in FP32 or BF16) into a quantized tensor (e.g. in FP8).

    """

    # NOTE: attribute documentation uses ``#:`` doc-comments placed directly
    # above each annotation. Bare string literals placed *before* an
    # attribute are no-op statements that Sphinx attaches to the *preceding*
    # attribute, which mislabelled every field.

    #: Whether to construct quantized tensors with "row-wise usage"
    #:
    #: Hand-wave explanation: Consider the matrix multiplication C = A *
    #: B^T (used in linear forward). Tensor Cores prefer "TN GEMMs" (in
    #: Fortran-style column-major order), so A and B should be in
    #: row-major order.
    rowwise_usage: bool

    #: Whether to construct quantized tensors with "column-wise usage"
    #:
    #: Hand-wave explanation: Consider the matrix multiplication C = A^T
    #: * B (used in linear backward wgrad). Tensor Cores prefer "TN
    #: GEMMs" (in Fortran-style column-major order), so A and B should be
    #: in column-major order.
    columnwise_usage: bool

    #: Whether to instantiate tensors for purely internal usage
    #:
    #: Internal tensors are storage classes with minimal logic. They have
    #: less overhead than PyTorch tensor sub-classes, but are not
    #: compatible with PyTorch's autograd infrastructure nor PyTorch
    #: operations.
    internal: bool

    #: Whether to solely optimize for matrix multiplication
    #:
    #: The resulting quantized tensors are not guaranteed to support any
    #: operation other than matrix multiplication. Use with care since
    #: this is likely to break communication, checkpointing, and many
    #: other features.
    optimize_for_gemm: bool

    def __init__(self, *, rowwise: bool, columnwise: bool) -> None:
        self.rowwise_usage = rowwise
        self.columnwise_usage = columnwise
        self.internal = False
        self.optimize_for_gemm = False

    def __repr__(self):
        # Fields are comma-separated with no separator after the last one.
        return (
            f"{self.__class__.__name__}("
            f"rowwise_usage={self.rowwise_usage}, "
            f"columnwise_usage={self.columnwise_usage}, "
            f"internal={self.internal}"
            ")"
        )

    def update_quantized(
        self,
        src: torch.Tensor,
        dst: QuantizedTensor,
        *,
        noop_flag: Optional[torch.Tensor] = None,
    ) -> QuantizedTensor:
        """Quantize ``src`` into the existing quantized tensor ``dst`` in-place.

        Must be implemented by subclasses; the base class raises
        ``NotImplementedError``.
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement update_quantized"
        )

    def quantize(
        self,
        tensor: torch.Tensor,
        *,
        out: Optional[QuantizedTensor] = None,
        dtype: Optional[torch.dtype] = None,  # pylint: disable=unused-argument # used by override
    ) -> QuantizedTensor:
        """Quantize tensor

        If ``out`` is given, quantize into it via ``update_quantized``.
        Otherwise build a new tensor with ``quantize_impl``. The call goes
        through the autograd function ``_QuantizeFunc.apply`` only when this
        quantizer is not ``internal`` and grad mode is enabled. Otherwise
        ``_QuantizeFunc.forward`` is invoked directly and no autograd node
        is recorded.
        """
        if out is not None:
            return self.update_quantized(tensor, out)
        if (not self.internal) and torch.is_grad_enabled():
            return _QuantizeFunc.apply(tensor, self.quantize_impl)
        return _QuantizeFunc.forward(None, tensor, self.quantize_impl)

    def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
        """Quantize tensor implementation (subclass responsibility)"""
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement quantize_impl function"
        )

    def multi_quantize(self, list_of_tensors):
        """Quantize multiple tensors

        Each tensor is quantized independently with ``quantize``. The
        results are returned as a list in input order.
        """
        return [self.quantize(tensor) for tensor in list_of_tensors]

    def __call__(self, tensor: torch.Tensor) -> QuantizedTensor:
        """Quantize tensor (alias for ``quantize``)"""
        return self.quantize(tensor)

    def make_empty(
        self,
        shape: Iterable[int],
        *,
        dtype: torch.dtype = torch.float32,
        device: Optional[Union[torch.device, str]] = None,
        requires_grad: bool = False,
        pin_memory: bool = False,
    ) -> QuantizedTensor:
        """Construct quantized tensor with uninitialized data

        ``device`` defaults to ``"cuda"`` and may be given as a string.
        Allocation is delegated to ``tex.create_empty_quantized_tensor``;
        ``requires_grad`` is applied to the result afterwards.
        """
        if device is None:
            device = torch.device("cuda")
        # Handle the device passed as string
        device = torch.device(device)
        result = tex.create_empty_quantized_tensor(
            self,
            list(shape),
            dtype,
            device,
            pin_memory,
        )
        if requires_grad:
            result.requires_grad_(True)
        return result

    def calibrate(self, tensor: torch.Tensor) -> None:
        """Calibrate quantizer state

        Updates quantization state as if quantizing a tensor, but
        without actually performing the quantization. The base class
        implementation is a no-op.
        """

    def set_usage(
        self, *, rowwise: Optional[bool] = None, columnwise: Optional[bool] = None
    ) -> None:
        """Set how the quantized tensor is expected to be used

        See documentation for `rowwise_usage` and `columnwise_usage`
        variables. Arguments left as ``None`` keep the current value.
        """
        if rowwise is not None:
            self.rowwise_usage = rowwise
        if columnwise is not None:
            self.columnwise_usage = columnwise

    def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
        """Symbolic function for ONNX export"""
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement onnx_quantize"
        )

    def onnx_dequantize(self, tensor) -> torch.Tensor:
        """Symbolic function for ONNX export"""
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement onnx_dequantize"
        )

    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
        """Returns recipe class that is compatible with this quantizer"""
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement _get_compatible_recipe"
        )

    def supports_only_rowwise_all_gather(self) -> bool:
        """Returns True if the quantizer supports only rowwise all-gather"""
        return False

    def is_quantizable(self, inp: torch.Tensor) -> bool:  # pylint: disable=unused-argument
        """Whether tensor supports quantized all-gather

        TODO: Consider a less misleading function name.
        """
        return True

    def get_usages(self) -> Dict[str, bool]:
        """Get the usage of the quantizer as ``{"rowwise": ..., "columnwise": ...}``"""
        return {
            "rowwise": self.rowwise_usage,
            "columnwise": self.columnwise_usage,
        }
class QuantizedTensor(torch.Tensor):
    """Abstract base class for tensor with quantized data

    This is a proxy class with the interface of a standard PyTorch
    tensor, but with data that has been encoded with some quantization
    scheme. Derived classes should implement the quantization scheme
    by overriding the `quantize_` and `dequantize` functions.

    """

    def __new__(
        cls,
        shape: Iterable[int],
        dtype: torch.dtype,
        *,
        requires_grad: bool = False,
        device: Optional[torch.device] = None,
    ):
        # We are assuming only contiguous tensors
        stride = _stride_from_shape(shape)
        instance = torch.Tensor._make_wrapper_subclass(
            cls,
            shape,
            strides=stride,
            storage_offset=0,
            dtype=dtype,
            layout=torch.strided,
            requires_grad=requires_grad,
            # Default to the current CUDA device when none is given.
            device=torch.cuda.current_device() if device is None else device,
        )
        return instance

    def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        """Convert quantized data to standard PyTorch tensor"""
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement dequantize function"
        )

    def quantize_(self, tensor: torch.Tensor) -> QuantizedTensor:
        """Update quantized data in-place"""
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement quantize_ function"
        )

    def detach(self) -> QuantizedTensor:
        """Create new quantized tensor with same data

        Output tensor must be detached from the current autograd
        graph.

        """
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement detach function"
        )

    def clear(self):
        """Deallocate this tensor's memory. Typically not needed and must be used carefully"""
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement clear function"
        )

    def __repr__(self, *, tensor_contents=None) -> str:
        return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})"

    def float(self) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        return self.dequantize(dtype=torch.float32)

    def bfloat16(self) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        return self.dequantize(dtype=torch.bfloat16)

    def half(self) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        return self.dequantize(dtype=torch.float16)

    def cpu(self, memory_format=torch.preserve_format) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        return self.dequantize().cpu(memory_format=memory_format)

    def expand_as(self, other: torch.Tensor) -> torch.Tensor:
        # pylint: disable=missing-function-docstring
        if other is self:
            # Note: expand_as is hackily used to create dummy autograd nodes
            # and access the backward graph (see
            # https://github.com/pytorch/pytorch/blob/238fb660851268f44ff88127887041fea352fe48/torch/nn/parallel/distributed.py#L1026).
            # We hackily add a dummy function to handle this case.
            return _IdentityFunc.apply(self)
        return super().expand_as(other)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        # The signature allows kwargs=None; normalize so the branches below
        # can safely use ``**kwargs``, ``kwargs.get`` and ``kwargs.items``.
        if kwargs is None:
            kwargs = {}

        # Detach op
        if func == torch.ops.aten.detach.default:
            return args[0].detach()

        # In-place copy op
        if func == torch.ops.aten.copy_.default:
            dst = args[0]
            src = args[1]
            if (
                isinstance(dst, QuantizedTensor)
                and isinstance(src, QuantizedTensor)
                and type(dst._quantizer) is type(src._quantizer)
                # Same usage keys with the same values.
                and src.get_usages() == dst.get_usages()
            ):
                # Fast path: copy the raw quantized buffers one-to-one.
                dst_tensors, dst_tensor_obj = dst.prepare_for_saving()
                src_tensors, src_tensor_obj = src.prepare_for_saving()
                for dst_tensor, src_tensor in zip(dst_tensors, src_tensors):
                    if dst_tensor is not None:
                        dst_tensor.copy_(src_tensor, *args[2:], **kwargs)
                dst_tensor_obj.restore_from_saved(dst_tensors)
                src_tensor_obj.restore_from_saved(src_tensors)
                return None
            if isinstance(dst, QuantizedTensor):
                dst.quantize_(src)
            else:
                if isinstance(src, QuantizedTensor):
                    src = src.dequantize()
                # Forward extra copy_ arguments (e.g. non_blocking) as the
                # fast path above does.
                dst.copy_(src, *args[2:], **kwargs)
            return None

        # View op
        if func == torch.ops.aten.view.default:
            raise NotImplementedError(f"{cls.__name__} class does not support tensor views")

        # Empty like op
        if func == torch.ops.aten.empty_like.default:
            tensor = args[0]
            device = kwargs.get("device", tensor.device)
            requires_grad = kwargs.get("requires_grad", tensor.requires_grad)
            pin_memory = kwargs.get("pin_memory", False)
            usage = tensor.get_usages()
            quantizer_usage = tensor._quantizer.get_usages()
            # Temporarily give the quantizer the source tensor's usage so the
            # new tensor gets matching data buffers, and always restore the
            # quantizer's own usage, even if allocation raises.
            tensor._quantizer.set_usage(**usage)
            try:
                out = tensor._quantizer.make_empty(
                    shape=tensor.shape,
                    dtype=tensor.dtype,
                    device=device,
                    requires_grad=requires_grad,
                    pin_memory=pin_memory,
                )
            finally:
                tensor._quantizer.set_usage(**quantizer_usage)
            return out

        if func == torch.ops.aten.numel.default:
            tensor = args[0]
            return math.prod(tensor.size())

        if func == torch.ops.aten.is_pinned.default:
            tensor = args[0]
            for t in tensor.get_data_tensors():
                if t is not None:
                    return func(t)
            return False  # Or error out?

        def maybe_unwrap(arg):
            # Replace quantized tensors with dequantized copies in their nominal dtype.
            if isinstance(arg, QuantizedTensor):
                return arg.dequantize(dtype=arg.dtype)
            return arg

        def maybe_update_inplace(arg, new_arg, schema_arg):
            # Re-quantize the op's result back into a quantized argument, but
            # only when the schema marks that argument as written.
            if (
                isinstance(arg, QuantizedTensor)
                and isinstance(new_arg, torch.Tensor)
                and hasattr(schema_arg, "alias_info")
                and hasattr(schema_arg.alias_info, "is_write")
                and schema_arg.alias_info.is_write
            ):
                arg.quantize_(new_arg)
            elif isinstance(arg, list) and isinstance(new_arg, list):
                # Recursively handle update for lists of tensors
                for a, na in zip(arg, new_arg):
                    maybe_update_inplace(a, na, schema_arg)

        # In-place op: dequantize, perform op, and quantize
        if func._schema.is_mutable:
            new_args = tree_map(maybe_unwrap, args)
            new_kwargs = tree_map(maybe_unwrap, kwargs)
            schema_args = func._schema.arguments
            super().__torch_dispatch__(func, types, new_args, new_kwargs)
            # Positional args are passed in schema order.
            for arg, new_arg, schema_arg in zip(args, new_args, schema_args):
                maybe_update_inplace(arg, new_arg, schema_arg)
            # Keyword args cannot be matched positionally: PyTorch omits
            # defaulted arguments (e.g. ``normal_(generator=g)`` drops
            # ``mean``/``std``), so look each schema argument up by name.
            schema_args_by_name = {schema_arg.name: schema_arg for schema_arg in schema_args}
            for name, kwarg in kwargs.items():
                maybe_update_inplace(kwarg, new_kwargs[name], schema_args_by_name.get(name))
            return None

        # Default op: dequantize and perform op
        args = tree_map(maybe_unwrap, args)
        kwargs = tree_map(maybe_unwrap, kwargs)
        out = super().__torch_dispatch__(func, types, args, kwargs)
        return out

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        # Do not force the QuantizedTensor type on the returned tensor
        return torch._C._disabled_torch_function_impl(func, types, args, kwargs)

    def contiguous(
        self, memory_format: torch.memory_format = torch.contiguous_format
    ) -> QuantizedTensor:
        # pylint: disable=missing-function-docstring
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement contiguous function"
        )

    def get_metadata(self) -> Dict[str, Any]:
        """Get keyword arguments for quantized tensor constructor

        Contains metadata so that the new quantized tensor has the
        same underlying quantized data.

        """
        raise NotImplementedError(
            f"{self.__class__.__name__} class does not implement get_metadata function"
        )

    @classmethod
    def make_like(
        cls,
        tensor: QuantizedTensor,
        *,
        shape: Optional[Iterable[int]] = None,
        dtype: Optional[torch.dtype] = None,
        requires_grad: bool = False,
    ) -> QuantizedTensor:
        """Create new quantized tensor

        By default, new tensor has the same attributes and underlying
        data. This function is intended to create view of tensors.
        ``shape``/``dtype`` default to those of ``tensor``; the remaining
        constructor arguments come from ``tensor.get_metadata()``.

        """
        shape = shape if shape is not None else tensor.shape
        dtype = dtype if dtype is not None else tensor.dtype
        kwargs = tensor.get_metadata()
        return cls(shape=shape, dtype=dtype, requires_grad=requires_grad, **kwargs)

    def to_dtype(self, dtype: torch.dtype) -> QuantizedTensor:
        """Create `QuantizedTensor` with given nominal dtype

        The new tensor has the same underlying data.

        """
        return self.__class__.make_like(self, dtype=dtype)