Skip to content

Commit bf3f1e1

Browse files
Fix modules ignoring default device when created via native P/Invoke (#1438)
MoveModule<T>() was returning modules unchanged when device=null, instead of resolving to torch.get_default_device(). This caused Embedding, EmbeddingBag, LSTM, GRU, RNN, LSTMCell, GRUCell, RNNCell, and PReLU to always be created on CPU regardless of the default device setting. Fix MoveModule to use GetDefaultDeviceAndType() to resolve defaults, aligning behavior with managed modules like Linear and Conv2d. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent bdc2bcb commit bf3f1e1

2 files changed

Lines changed: 33 additions & 4 deletions

File tree

src/TorchSharp/NN/Module.cs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,10 +1138,8 @@ protected static (Device device, ScalarType dtype) GetDefaultDeviceAndType(Devic
11381138
internal T MoveModule<T>(Device? device, ScalarType? dtype) where T : Module
11391139
{
11401140
T module = (T)this;
1141-
1142-
return device != null ?
1143-
(dtype.HasValue ? (T)module._to(device, dtype.Value, false) : (T)module._to(device.type, device.index, false)) :
1144-
(dtype.HasValue ? (T)module._to(dtype.Value, false) : module);
1141+
var (targetDevice, targetDtype) = GetDefaultDeviceAndType(device, dtype);
1142+
return (T)module._to(targetDevice, targetDtype, false);
11451143
}
11461144

11471145
protected void ClearModules() { _internal_submodules.clear(); }

test/TorchSharpTest/NN.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5264,6 +5264,37 @@ public void TestEmbeddingBagFromPretrained()
52645264
}
52655265
}
52665266

5267+
[Fact]
5268+
public void TestEmbeddingDefaultDevice()
5269+
{
5270+
// Regression test for https://github.com/dotnet/TorchSharp/issues/1438
5271+
// Embedding (and other native modules) ignored torch.set_default_device(),
5272+
// causing device mismatch when called with tensors on the default device.
5273+
var defaultDevice = torch.get_default_device();
5274+
5275+
try {
5276+
torch.set_default_device(torch.META);
5277+
5278+
// Before the fix, these modules were created on CPU despite META being
5279+
// the default device, which would cause "Expected all tensors to be on
5280+
// the same device" errors when used with tensors on the default device.
5281+
using (var emb = Embedding(2, 3)) {
5282+
Assert.Equal(DeviceType.META, emb.weight!.device_type);
5283+
}
5284+
5285+
using (var emb = EmbeddingBag(2, 3)) {
5286+
Assert.Equal(DeviceType.META, emb.weight!.device_type);
5287+
}
5288+
5289+
// Verify explicit device still takes precedence over default
5290+
using (var emb = Embedding(2, 3, device: torch.CPU)) {
5291+
Assert.Equal(DeviceType.CPU, emb.weight!.device_type);
5292+
}
5293+
} finally {
5294+
torch.set_default_device(defaultDevice);
5295+
}
5296+
}
5297+
52675298
[Fact]
52685299
public void TestOneHotEncoding1()
52695300
{

0 commit comments

Comments
 (0)