We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents 6837b28 + 9cc12f0 commit 72c38d5Copy full SHA for 72c38d5
1 file changed
src/graphnet/models/components/embedding.py
@@ -84,7 +84,6 @@ def __init__(
84
super().__init__()
85
86
self.sin_emb = SinusoidalPosEmb(dim=seq_length, scaled=scaled)
87
- self.aux_emb = nn.Embedding(2, seq_length // 2)
88
self.sin_emb2 = SinusoidalPosEmb(dim=seq_length // 2, scaled=scaled)
89
90
if n_features < 4:
@@ -93,7 +92,7 @@ def __init__(
93
92
f"{n_features} features."
94
)
95
elif n_features >= 6:
96
-
+ self.aux_emb = nn.Embedding(2, seq_length // 2)
97
hidden_dim = 6 * seq_length
98
else:
99
hidden_dim = int((n_features + 0.5) * seq_length)
0 commit comments