From 16b6b0b84ace53420642bc728f74a1de75042785 Mon Sep 17 00:00:00 2001 From: alinpahontu2912 Date: Mon, 16 Feb 2026 20:37:30 +0100 Subject: [PATCH] Fix modules ignoring default device when created via native P/Invoke (#1438) MoveModule() was returning modules unchanged when device=null, instead of resolving to torch.get_default_device(). This caused Embedding, EmbeddingBag, LSTM, GRU, RNN, LSTMCell, GRUCell, RNNCell, and PReLU to always be created on CPU regardless of the default device setting. Fix MoveModule to use GetDefaultDeviceAndType() to resolve defaults, aligning behavior with managed modules like Linear and Conv2d. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/TorchSharp/NN/Module.cs | 6 ++---- test/TorchSharpTest/NN.cs | 31 +++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/src/TorchSharp/NN/Module.cs b/src/TorchSharp/NN/Module.cs index 498bda549..59eacee79 100644 --- a/src/TorchSharp/NN/Module.cs +++ b/src/TorchSharp/NN/Module.cs @@ -1138,10 +1138,8 @@ protected static (Device device, ScalarType dtype) GetDefaultDeviceAndType(Devic internal T MoveModule(Device? device, ScalarType? dtype) where T : Module { T module = (T)this; - - return device != null ? - (dtype.HasValue ? (T)module._to(device, dtype.Value, false) : (T)module._to(device.type, device.index, false)) : - (dtype.HasValue ? (T)module._to(dtype.Value, false) : module); + var (targetDevice, targetDtype) = GetDefaultDeviceAndType(device, dtype); + return (T)module._to(targetDevice, targetDtype, false); } protected void ClearModules() { _internal_submodules.clear(); } diff --git a/test/TorchSharpTest/NN.cs b/test/TorchSharpTest/NN.cs index 86e339d7f..f2ed50db3 100644 --- a/test/TorchSharpTest/NN.cs +++ b/test/TorchSharpTest/NN.cs @@ -5264,6 +5264,37 @@ public void TestEmbeddingBagFromPretrained() } } + [Fact] + public void TestEmbeddingDefaultDevice() + { + // Regression test for https://github.com/dotnet/TorchSharp/issues/1438 + // Embedding (and other native modules) ignored torch.set_default_device(), + // causing device mismatch when called with tensors on the default device. + var defaultDevice = torch.get_default_device(); + + try { + torch.set_default_device(torch.META); + + // Before the fix, these modules were created on CPU despite META being + // the default device, which would cause "Expected all tensors to be on + // the same device" errors when used with tensors on the default device. + using (var emb = Embedding(2, 3)) { + Assert.Equal(DeviceType.META, emb.weight!.device_type); + } + + using (var emb = EmbeddingBag(2, 3)) { + Assert.Equal(DeviceType.META, emb.weight!.device_type); + } + + // Verify explicit device still takes precedence over default + using (var emb = Embedding(2, 3, device: torch.CPU)) { + Assert.Equal(DeviceType.CPU, emb.weight!.device_type); + } + } finally { + torch.set_default_device(defaultDevice); + } + } + [Fact] public void TestOneHotEncoding1() {