-
Notifications
You must be signed in to change notification settings - Fork 110
Expand file tree
/
Copy pathembedding.py
More file actions
32 lines (26 loc) · 900 Bytes
/
embedding.py
File metadata and controls
32 lines (26 loc) · 900 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor
class _UnsupportedEmbeddingArgumentError(NotImplementedError, AssertionError):
    """Raised when :func:`embedding` is given an option it does not implement.

    It is a ``NotImplementedError`` because that describes the condition, and
    also an ``AssertionError`` for backward compatibility with callers that
    catch the error raised by the earlier ``assert``-based check.
    """


def embedding(
    input: Tensor,
    weight: Tensor,
    padding_idx=None,
    max_norm=None,
    norm_type=2.0,
    scale_grad_by_freq=False,
    sparse=False,
    *,
    out=None,
) -> Tensor:
    r"""Look up embedding vectors in ``weight`` for the indices in ``input``.

    The positional signature mirrors ``torch.nn.functional.embedding`` (plus a
    keyword-only ``out``). Only the plain lookup is implemented, so
    ``padding_idx``, ``max_norm``, ``scale_grad_by_freq`` and ``sparse`` must
    be left at their defaults. ``norm_type`` is accepted for signature
    compatibility and is otherwise unused.

    Args:
        input: Tensor of indices to look up. Index dtype/shape requirements
            are enforced by the native op (TODO confirm exact constraints).
        weight: Embedding table tensor; presumably ``(num_embeddings, dim)``
            as in torch. Verify against the native implementation.
        padding_idx: Not supported; must be ``None``.
        max_norm: Not supported; must be ``None``.
        norm_type: Unused.
        scale_grad_by_freq: Not supported; must be falsy.
        sparse: Not supported; must be falsy.
        out: Optional pre-allocated tensor to receive the result.

    Returns:
        A new ``Tensor`` wrapping the result of ``_infinicore.embedding`` when
        ``out`` is ``None``. Otherwise, ``out`` itself after it has been passed
        to ``_infinicore.embedding_``, which presumably writes into it in place.

    Raises:
        NotImplementedError: If any unsupported option is set. The raised
            error is also an instance of ``AssertionError``.
    """
    unsupported = [
        name
        for name, is_set in (
            ("padding_idx", padding_idx is not None),
            ("max_norm", max_norm is not None),
            ("scale_grad_by_freq", bool(scale_grad_by_freq)),
            ("sparse", bool(sparse)),
        )
        if is_set
    ]
    if unsupported:
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``. None of these options is forwarded to the native
        # call below, so without this check they would be silently ignored.
        raise _UnsupportedEmbeddingArgumentError(
            f"Unsupported parameters: {', '.join(unsupported)}."
        )

    # Note: embedding supports device-side input for graph recording; the
    # C++ implementation handles both CPU and device-side inputs.
    if out is None:
        return Tensor(_infinicore.embedding(input._underlying, weight._underlying))
    _infinicore.embedding_(out._underlying, input._underlying, weight._underlying)
    return out