Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions RELEASENOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ This release upgrades the libtorch backend to v2.7.1, using CUDA 12.8.

# NuGet Version 0.105.1

__Breaking Changes__:

`torch.nn.functional.scaled_dot_product_attention`'s function signature has been changed. The `is_casual` argument has been renamed to `is_causal`.<br/>

Comment thread
DillionLowry marked this conversation as resolved.
Outdated
__Bug Fixes__:

#1426 Sequential.eval() does not put model into eval mode<br/>
Expand All @@ -45,6 +49,7 @@ __API Changes__:
`torch.optim.lr_scheduler.PolynomialLR` `power` type has been corrected, is now double.<br/>
Returning an input tensor has been corrected, is now `alias()`.<br/>
Add `torchvision.transforms.Resize` `interpolation` and `antialias`.<br />
Add optional `scale` and `enable_gqa` arguments to `torch.nn.functional.scaled_dot_product_attention`.<br/>

# NuGet Version 0.105.0

Expand Down
5 changes: 3 additions & 2 deletions src/Native/LibTorchSharp/THSNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1064,9 +1064,10 @@ Tensor THSNN_unfold(const Tensor input, const int64_t kernel1, const int64_t ker
CATCH_TENSOR(torch::nn::functional::unfold(*input, opts));
}

// Scaled dot-product attention over query/key/value tensors.
//   attention_mask - optional additive mask; nullptr means "no mask".
//   p              - dropout probability.
//   is_causal      - apply causal masking (mutually exclusive with a mask).
//   scale          - optional softmax scaling factor; nullptr selects the
//                    libtorch default of 1/sqrt(E).
//   enable_gqa     - enable grouped-query attention (query heads a multiple
//                    of key/value heads).
Tensor THSNN_scaled_dot_product_attention(const Tensor query, const Tensor key, const Tensor value, const Tensor attention_mask, double p, bool is_causal, double* scale, bool enable_gqa)
{
    // Translate nullable native pointers into the c10::optional values the
    // libtorch API expects.
    auto mask = attention_mask == nullptr ? c10::nullopt : c10::optional<at::Tensor>(*attention_mask);
    auto scl = (scale == nullptr) ? c10::nullopt : c10::optional<double>(*scale);

    CATCH_TENSOR(torch::scaled_dot_product_attention(*query, *key, *value, mask, p, is_causal, scl, enable_gqa));
}
2 changes: 1 addition & 1 deletion src/Native/LibTorchSharp/THSNN.h
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ EXPORT_API(Tensor) THSNN_cosine_similarity(const Tensor input1, const Tensor i

EXPORT_API(Tensor) THSNN_pairwise_distance(const Tensor input1, const Tensor input2, double p, double eps, bool keepdim);

// Scaled dot-product attention; scale may be nullptr to use the libtorch
// default of 1/sqrt(E). See THSNN.cpp for the full parameter contract.
EXPORT_API(Tensor) THSNN_scaled_dot_product_attention(const Tensor query, const Tensor key, const Tensor value, const Tensor attention_mask, double p, bool is_causal, double* scale, bool enable_gqa);

// Initializers

Expand Down
22 changes: 16 additions & 6 deletions src/TorchSharp/NN/Transformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,25 @@ public static partial class functional
/// A float mask of the same type as query, key, value that is added to the attention score.
/// </param>
/// <param name="p">Dropout probability</param>
/// <param name="is_causal">If true, assumes causal attention masking; it is an error to set both attn_mask and is_causal.</param>
/// <param name="scale">Scaling factor applied prior to softmax. If null, 1/sqrt(E) is used.</param>
/// <param name="enable_gqa">If true, enable Group Query Attention (query heads must be a multiple of key/value heads).</param>
/// <returns>The attention output tensor, with the same shape as <paramref name="query"/>.</returns>
/// <exception cref="ArgumentException">If p is negative, if both attn_mask and is_causal are set, or if any input has fewer than 2 dimensions.</exception>
/// <exception cref="InvalidOperationException">If the head counts of query and key/value differ while GQA is disabled.</exception>
public static Tensor scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask = null, double p = 0.0, bool is_causal = false, double? scale = null, bool enable_gqa = false)
{
    if (p < 0) throw new ArgumentException("Dropout probability must be greater than or equal to zero.");
    // An explicit mask combined with causal masking is ambiguous; libtorch rejects it as well.
    if (is_causal && attn_mask is not null) throw new ArgumentException("Causal attention masking cannot pass a mask.");
    if (query.dim() < 2 || key.dim() < 2 || value.dim() < 2) throw new ArgumentException("Query, key, and value must have at least 2 dimensions.");
    // Without GQA, the head dimension (index -3) must agree across query, key, and value.
    if (!enable_gqa && (query.size(-3) != key.size(-3) || query.size(-3) != value.size(-3))) throw new InvalidOperationException("Query and key/value heads must be equal when Group Query Attention is not enabled.");

    // The native layer takes the optional scale as a nullable double*:
    // null selects the libtorch default of 1/sqrt(E).
    var _scale = scale.GetValueOrDefault();

    unsafe {
        double* scalePtr = scale.HasValue ? &_scale : null;
        var res = THSNN_scaled_dot_product_attention(query.Handle, key.Handle, value.Handle, attn_mask is null ? IntPtr.Zero : attn_mask.Handle, p, is_causal, (IntPtr)scalePtr, enable_gqa);
        if (res == IntPtr.Zero) { torch.CheckForErrors(); }
        return new Tensor(res);
    }
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ internal static extern IntPtr THSNN_custom_module(
internal static extern IntPtr THSNN_local_response_norm(IntPtr input, long size, double alpha, double beta, double k);

[DllImport("LibTorchSharp")]
// scale is a nullable native double* (IntPtr.Zero = use the 1/sqrt(E) default);
// the bools are marshalled as single unsigned bytes to match the C ABI.
internal static extern IntPtr THSNN_scaled_dot_product_attention(IntPtr query, IntPtr key, IntPtr value, IntPtr attention_mask, double p, [MarshalAs(UnmanagedType.U1)] bool is_causal, IntPtr scale, [MarshalAs(UnmanagedType.U1)] bool enable_gqa);
}
#pragma warning restore CA2101
}
}
38 changes: 37 additions & 1 deletion test/TorchSharpTest/NN.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5338,7 +5338,43 @@ public void TestScaledDotProductWithMask()
Assert.Equal(query.shape, x.shape);
Assert.Equal(value, x);

Assert.Throws<ArgumentException>(() => torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask: mask, is_casual: true));
Assert.Throws<ArgumentException>(() => torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask: mask, is_causal: true));
}
[Fact]
public void TestScaledDotProductWithScale()
{
    // A custom softmax scale (0.5) differs from the default 1/sqrt(64) = 0.125,
    // so the two outputs must not be numerically close.
    var q = torch.rand(32, 8, 128, 64) * 0.25;
    var k = torch.rand(32, 8, 128, 64) * 0.5;
    var v = torch.rand(32, 8, 128, 64) * 0.125;

    var baseline = torch.nn.functional.scaled_dot_product_attention(q, k, v);
    var rescaled = torch.nn.functional.scaled_dot_product_attention(q, k, v, scale: 0.5);

    Assert.Equal(q.shape, rescaled.shape);
    Assert.False(torch.allclose(baseline, rescaled, rtol: 1e-5, atol: 1e-5));
}

[Fact]
public void TestScaledDotProductWithGQA()
{
    // Grouped-query attention: 8 query heads share 2 key/value heads
    // (query heads are a multiple of key/value heads).
    const int batch = 2, queryHeads = 8, kvHeads = 2, seqLen = 16, headDim = 64;

    var q = torch.ones(batch, queryHeads, seqLen, headDim) * 0.25;
    var k = torch.ones(batch, kvHeads, seqLen, headDim) * 0.5;
    var v = torch.ones(batch, kvHeads, seqLen, headDim) * 0.125;

    // Mismatched head counts must be rejected while GQA is disabled.
    Assert.Throws<InvalidOperationException>(() =>
        torch.nn.functional.scaled_dot_product_attention(q, k, v, enable_gqa: false));

    // With GQA enabled the same inputs succeed and keep the query shape.
    var result = torch.nn.functional.scaled_dot_product_attention(q, k, v, enable_gqa: true);

    Assert.Equal(q.shape, result.shape);
}

[Fact]
Expand Down