Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/TorchSharp/NN/Module.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1138,10 +1138,8 @@ protected static (Device device, ScalarType dtype) GetDefaultDeviceAndType(Devic
internal T MoveModule<T>(Device? device, ScalarType? dtype) where T : Module
{
T module = (T)this;

return device != null ?
(dtype.HasValue ? (T)module._to(device, dtype.Value, false) : (T)module._to(device.type, device.index, false)) :
(dtype.HasValue ? (T)module._to(dtype.Value, false) : module);
var (targetDevice, targetDtype) = GetDefaultDeviceAndType(device, dtype);
return (T)module._to(targetDevice, targetDtype, false);
}

protected void ClearModules() { _internal_submodules.clear(); }
Expand Down
31 changes: 31 additions & 0 deletions test/TorchSharpTest/NN.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5264,6 +5264,37 @@ public void TestEmbeddingBagFromPretrained()
}
}

[Fact]
public void TestEmbeddingDefaultDevice()
{
// Regression test for https://github.com/dotnet/TorchSharp/issues/1438
// Embedding (and other native modules) ignored torch.set_default_device(),
// causing device mismatch when called with tensors on the default device.
var defaultDevice = torch.get_default_device();

try {
torch.set_default_device(torch.META);

// Before the fix, these modules were created on CPU despite META being
// the default device, which would cause "Expected all tensors to be on
// the same device" errors when used with tensors on the default device.
using (var emb = Embedding(2, 3)) {
Assert.Equal(DeviceType.META, emb.weight!.device_type);
}

using (var emb = EmbeddingBag(2, 3)) {
Assert.Equal(DeviceType.META, emb.weight!.device_type);
}

// Verify explicit device still takes precedence over default
using (var emb = Embedding(2, 3, device: torch.CPU)) {
Assert.Equal(DeviceType.CPU, emb.weight!.device_type);
}
} finally {
torch.set_default_device(defaultDevice);
}
}

[Fact]
public void TestOneHotEncoding1()
{
Expand Down
Loading