diff --git a/src/TorchSharp/NN/Module.cs b/src/TorchSharp/NN/Module.cs index 498bda549..59eacee79 100644 --- a/src/TorchSharp/NN/Module.cs +++ b/src/TorchSharp/NN/Module.cs @@ -1138,10 +1138,8 @@ protected static (Device device, ScalarType dtype) GetDefaultDeviceAndType(Devic internal T MoveModule(Device? device, ScalarType? dtype) where T : Module { T module = (T)this; - - return device != null ? - (dtype.HasValue ? (T)module._to(device, dtype.Value, false) : (T)module._to(device.type, device.index, false)) : - (dtype.HasValue ? (T)module._to(dtype.Value, false) : module); + var (targetDevice, targetDtype) = GetDefaultDeviceAndType(device, dtype); + return (T)module._to(targetDevice, targetDtype, false); } protected void ClearModules() { _internal_submodules.clear(); } diff --git a/test/TorchSharpTest/NN.cs b/test/TorchSharpTest/NN.cs index 86e339d7f..f2ed50db3 100644 --- a/test/TorchSharpTest/NN.cs +++ b/test/TorchSharpTest/NN.cs @@ -5264,6 +5264,37 @@ public void TestEmbeddingBagFromPretrained() } } + [Fact] + public void TestEmbeddingDefaultDevice() + { + // Regression test for https://github.com/dotnet/TorchSharp/issues/1438 + // Embedding (and other native modules) ignored torch.set_default_device(), + // causing device mismatch when called with tensors on the default device. + var defaultDevice = torch.get_default_device(); + + try { + torch.set_default_device(torch.META); + + // Before the fix, these modules were created on CPU despite META being + // the default device, which would cause "Expected all tensors to be on + // the same device" errors when used with tensors on the default device. + using (var emb = Embedding(2, 3)) { + Assert.Equal(DeviceType.META, emb.weight!.device_type); + } + + using (var emb = EmbeddingBag(2, 3)) { + Assert.Equal(DeviceType.META, emb.weight!.device_type); + } + + // Verify explicit device still takes precedence over default + using (var emb = Embedding(2, 3, device: torch.CPU)) { + Assert.Equal(DeviceType.CPU, emb.weight!.device_type); + } + } finally { + torch.set_default_device(defaultDevice); + } + } + [Fact] public void TestOneHotEncoding1() {