Skip to content

Commit 5c81ff0

Browse files
ezhulenev authored and tensorflower-gardener committed
PR tensorflow#37395: [xla:pjrt] Switch PjRt to xla/runtime ids: ProcessId, DeviceId, ChipId
Imported from GitHub PR openxla/xla#37395 Use the same strongly-typed ids across PjRt and XLA Copybara import of the project: -- a0e56cb461fbe90a35490611b4164d2d955d4985 by Eugene Zhulenev <ezhulenev@openxla.org>: [xla:pjrt] Switch PjRt to xla/runtime ids: ProcessId, DeviceId, ChipId Merging this change closes tensorflow#37395 PiperOrigin-RevId: 868190782
1 parent 4943e58 commit 5c81ff0

7 files changed

Lines changed: 110 additions & 17 deletions

File tree

third_party/xla/xla/backends/gpu/transforms/collectives/all_reduce_blueconnect.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,14 +82,14 @@ static std::optional<GlobalDeviceId> TryConvertingReplicaIdToDeviceId(
8282
// devices on different partitions.
8383
return std::nullopt;
8484
}
85-
return GlobalDeviceId{device_assignment(replica_id, /*computation_id=*/0)};
85+
return GlobalDeviceId(device_assignment(replica_id, /*computation_id=*/0));
8686
}
8787
if (collective_group_mode ==
8888
CollectiveOpGroupMode::COLLECTIVE_OP_GROUP_MODE_FLATTENED_ID) {
8989
int partition_count = device_assignment.computation_count();
9090
int64_t actual_replica_id = replica_id / partition_count;
9191
int64_t partition_id = replica_id % partition_count;
92-
return GlobalDeviceId{device_assignment(actual_replica_id, partition_id)};
92+
return GlobalDeviceId(device_assignment(actual_replica_id, partition_id));
9393
}
9494

9595
// COLLECTIVE_OP_GROUP_MODE_CROSS_PARTITION and

third_party/xla/xla/pjrt/BUILD

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,10 @@ cc_library(
554554
visibility = internal_visibility([":friends"]),
555555
deps = [
556556
"//xla/pjrt/proto:pjrt_value_type_proto_cc",
557-
"//xla/tsl/lib/gtl:int_type",
557+
"//xla/runtime:chip_id",
558+
"//xla/runtime:device_id",
559+
"//xla/runtime:process_id",
560+
"@com_google_absl//absl/base:core_headers",
558561
"@com_google_absl//absl/container:inlined_vector",
559562
],
560563
)

third_party/xla/xla/pjrt/pjrt_common.h

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@ limitations under the License.
2121
#include <variant>
2222
#include <vector>
2323

24+
#include "absl/base/macros.h"
2425
#include "absl/container/inlined_vector.h"
2526
#include "xla/pjrt/proto/pjrt_value_type.pb.h"
26-
#include "xla/tsl/lib/gtl/int_type.h"
27+
#include "xla/runtime/chip_id.h"
28+
#include "xla/runtime/device_id.h"
29+
#include "xla/runtime/process_id.h"
2730

2831
namespace xla {
2932

@@ -50,13 +53,11 @@ PjRtIdContainer<Id> MakeContinuousIds(int start, int size) {
5053
return container;
5154
}
5255

53-
// The strong-typed integer classes to better disambiguate different IDs for
54-
// PJRT devices.
55-
TSL_LIB_GTL_DEFINE_INT_TYPE(PjRtProcessId, int32_t);
56-
TSL_LIB_GTL_DEFINE_INT_TYPE(PjRtGlobalChipId, int32_t);
57-
TSL_LIB_GTL_DEFINE_INT_TYPE(PjRtGlobalDeviceId, int32_t);
58-
TSL_LIB_GTL_DEFINE_INT_TYPE(PjRtLocalDeviceId, int32_t);
59-
TSL_LIB_GTL_DEFINE_INT_TYPE(PjRtLocalHardwareId, int32_t);
56+
// Deprecated PjRt-prefixed spellings of the strongly-typed ids that are now
// declared in xla/runtime (process_id.h, device_id.h, chip_id.h). Each alias is
// tagged with ABSL_DEPRECATE_AND_INLINE (from absl/base/macros.h), so new code
// should name the target type directly. Note that `PjRtLocalHardwareId` maps to
// `LocalChipId`; the other aliases only drop the `PjRt` prefix.
using PjRtProcessId ABSL_DEPRECATE_AND_INLINE() = ProcessId;
57+
using PjRtLocalDeviceId ABSL_DEPRECATE_AND_INLINE() = LocalDeviceId;
58+
using PjRtGlobalDeviceId ABSL_DEPRECATE_AND_INLINE() = GlobalDeviceId;
59+
using PjRtLocalHardwareId ABSL_DEPRECATE_AND_INLINE() = LocalChipId;
60+
using PjRtGlobalChipId ABSL_DEPRECATE_AND_INLINE() = GlobalChipId;
6061

6162
using PjRtPlatformId = uint64_t;
6263

third_party/xla/xla/runtime/BUILD

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,17 @@ cc_library(
165165
],
166166
)
167167

168+
# Header-only target for chip_id.h, which declares the strongly-typed
# LocalChipId / GlobalChipId integer types and their string helpers.
cc_library(
169+
name = "chip_id",
170+
hdrs = ["chip_id.h"],
171+
deps = [
172+
"//xla/tsl/lib/gtl:int_type",
173+
"@com_google_absl//absl/strings",
174+
"@com_google_absl//absl/strings:string_view",
175+
"@com_google_absl//absl/types:span",
176+
],
177+
)
178+
168179
cc_library(
169180
name = "process_id",
170181
hdrs = ["process_id.h"],
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/* Copyright 2026 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef XLA_RUNTIME_CHIP_ID_H_
17+
#define XLA_RUNTIME_CHIP_ID_H_
18+
19+
#include <cstddef>
20+
#include <cstdint>
21+
#include <string>
22+
23+
#include "absl/strings/str_cat.h"
24+
#include "absl/strings/str_join.h"
25+
#include "absl/strings/string_view.h"
26+
#include "absl/types/span.h"
27+
#include "xla/tsl/lib/gtl/int_type.h"
28+
29+
namespace xla {
30+
31+
// Some of the accelerator devices consist of multiple chips, and XLA might need
32+
// to address them separately. For example GB200 (1) from NVIDIA can be viewed
33+
// as a single device that consists of one Grace CPU and two Blackwell GPUs (one
34+
// `LocalDeviceId` with two `LocalChipId`s). Some TPU chips from Google have two
35+
// tensor cores (2) that appear as a single device to JAX/XLA users, and XLA
36+
// (including PjRt runtime) needs to address these chips separately.
37+
//
38+
// (1) https://www.nvidia.com/en-us/data-center/gb200-nvl72/
39+
// (2)
40+
// https://docs.jax.dev/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration
41+
42+
// Strongly-typed integer type for identifying local chips that belong to one of
43+
// the local devices.
44+
TSL_LIB_GTL_DEFINE_INT_TYPE(LocalChipId, int32_t);
45+
46+
// Strongly-typed integer type for identifying global chips in a distributed
47+
// execution.
48+
TSL_LIB_GTL_DEFINE_INT_TYPE(GlobalChipId, int32_t);
49+
50+
template <typename Sink>
51+
void AbslStringify(Sink& sink, LocalChipId id) {
52+
absl::Format(&sink, "%d", id.value());
53+
}
54+
55+
template <typename Sink>
56+
void AbslStringify(Sink& sink, GlobalChipId id) {
57+
absl::Format(&sink, "%d", id.value());
58+
}
59+
60+
// StrJoin for global chip ids that shortens long list of ids for readability.
61+
//
62+
// It is not uncommon to see in XLA a list of global chips with more than 1k
63+
// of entries. We don't need to print them all to get a human readable list
64+
// of chips for logging and debugging.
65+
inline std::string HumanReadableChips(absl::Span<const GlobalChipId> chips,
66+
absl::string_view separator = ",",
67+
size_t first = 8, size_t last = 2) {
68+
if (chips.size() > first + last) {
69+
return absl::StrCat(absl::StrJoin(chips.first(first), separator), "...",
70+
absl::StrJoin(chips.last(last), separator));
71+
}
72+
return absl::StrJoin(chips, separator);
73+
}
74+
75+
} // namespace xla
76+
77+
#endif // XLA_RUNTIME_CHIP_ID_H_

third_party/xla/xla/runtime/device_id.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ namespace xla {
3333
// system. XLA doesn't have a strong opinion about what global numbering scheme
3434
// is applied to GPUs; the user must provide a local -> global mapping via
3535
// GpuExecutableRunOptions for the local GPUs.
36-
TSL_LIB_GTL_DEFINE_INT_TYPE(GlobalDeviceId, int64_t);
37-
TSL_LIB_GTL_DEFINE_INT_TYPE(LocalDeviceId, int64_t);
36+
TSL_LIB_GTL_DEFINE_INT_TYPE(GlobalDeviceId, int32_t);
37+
TSL_LIB_GTL_DEFINE_INT_TYPE(LocalDeviceId, int32_t);
3838

3939
using ::tsl::IncarnationId; // NOLINT(misc-unused-using-decls)
4040

@@ -48,7 +48,8 @@ void AbslStringify(Sink& sink, LocalDeviceId id) {
4848
absl::Format(&sink, "%d", id.value());
4949
}
5050

51-
// StrJoin for global devices that shortens long list of devices for readbility.
51+
// StrJoin for global devices that shortens long list of devices for
52+
// readability.
5253
//
5354
// It is not uncommon to see in XLA a list of global devices with more than 1k
5455
// of entries. We don't need to print them all to get a human readable list

third_party/xla/xla/runtime/process_id.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,18 @@ namespace xla {
3131
// Strongly-typed integer type for identifying processes in a distributed
3232
// execution. Processes can belong to the same physical host, or can be
3333
// distributed over multiple nodes.
34-
TSL_LIB_GTL_DEFINE_INT_TYPE(ProcessId, int64_t);
34+
TSL_LIB_GTL_DEFINE_INT_TYPE(ProcessId, int32_t);
3535

3636
template <typename Sink>
3737
void AbslStringify(Sink& sink, ProcessId id) {
3838
absl::Format(&sink, "%d", id.value());
3939
}
4040

41-
// StrJoin for processes that shortens long list of processes for readbility.
41+
// StrJoin for processes that shortens long list of processes for readability.
4242
//
4343
// It is not uncommon to see in XLA a list of processes with more than 1k
4444
// of entries. We don't need to print them all to get a human readable list
45-
// of proceses for logging and debugging.
45+
// of processes for logging and debugging.
4646
inline std::string HumanReadableProcesses(absl::Span<const ProcessId> processes,
4747
absl::string_view separator = ",",
4848
size_t first = 10, size_t last = 4) {

0 commit comments

Comments
 (0)