-
Notifications
You must be signed in to change notification settings - Fork 110
Expand file tree
/
Copy pathdevice_event.py
More file actions
136 lines (111 loc) · 4.52 KB
/
device_event.py
File metadata and controls
136 lines (111 loc) · 4.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import infinicore._device
from infinicore.lib import _infinicore
class DeviceEvent:
    """A device event for timing operations and synchronization across devices.

    Similar to torch.cuda.Event, this class provides functionality to:
    - Record events on specific device streams
    - Synchronize with events
    - Measure elapsed time between events
    - Query event completion status
    - Make streams wait for events

    Args:
        enable_timing: Whether the event should record timing data. Default: False.
        device: Target device for this event. If None, uses current device.

    Note:
        Blocking, inter-process and external events cannot be requested through
        the constructor yet, so ``blocking`` and ``interprocess`` always report
        False.
    """

    # Flag value passed to the native constructor when timing is not requested.
    _FLAG_DISABLE_TIMING = 0x2

    # Class-level defaults for options the constructor does not accept yet.
    # Without them the ``blocking`` / ``interprocess`` properties and
    # ``__repr__`` would raise AttributeError on every instance.
    _blocking = False
    _interprocess = False
    _external = False

    def __init__(self, enable_timing=False, device=None):
        flags = 0 if enable_timing else self._FLAG_DISABLE_TIMING

        # Kept on the Python side so elapsed_time() can reject untimed events.
        self._enable_timing = enable_timing

        # The native constructor is called with only the arguments that carry
        # information: (), (flags), (device) or (device, flags).
        ctor_args = []
        if device is not None:
            ctor_args.append(device._underlying)
        if flags:
            ctor_args.append(flags)
        self._underlying = _infinicore.DeviceEvent(*ctor_args)

    def record(self, stream=None):
        """Record the event.

        Args:
            stream: Stream to record the event on. If None, uses current stream.
        """
        if stream is None:
            self._underlying.record()
        else:
            self._underlying.record(stream)

    def synchronize(self):
        """Wait for the event to complete (blocking)."""
        self._underlying.synchronize()

    def query(self):
        """Check if the event has been completed.

        Returns:
            bool: True if completed, False otherwise.
        """
        return self._underlying.query()

    def elapsed_time(self, other):
        """Calculate elapsed time between this event and another event.

        Args:
            other: The other DeviceEvent to compare with.

        Returns:
            float: Elapsed time in milliseconds between this event and the
            other event.

        Raises:
            RuntimeError: If timing is disabled on either event. The underlying
                implementation may also raise if the events are on different
                devices or have not been recorded.
        """
        if not self._enable_timing or not other._enable_timing:
            raise RuntimeError("Cannot measure elapsed time when timing is disabled")
        return self._underlying.elapsed_time(other._underlying)

    def wait(self, stream=None):
        """Make a stream wait for this event to complete.

        Args:
            stream: Stream to make wait for this event. If None, uses current
                stream.
        """
        # ``stream`` is forwarded unchanged (including None) to the native call.
        self._underlying.wait(stream)

    @property
    def device(self):
        """Get the device where this event was created."""
        return infinicore.device._from_infinicore_device(self._underlying.device)

    @property
    def is_recorded(self):
        """Check if the event has been recorded."""
        return self._underlying.is_recorded

    @property
    def enable_timing(self):
        """Whether this event records timing data."""
        return self._enable_timing

    @property
    def blocking(self):
        """Whether this event uses blocking synchronization (currently always False)."""
        return self._blocking

    @property
    def interprocess(self):
        """Whether this event can be used for inter-process communication (currently always False)."""
        return self._interprocess

    def __repr__(self):
        flags_str = []
        if not self._enable_timing:
            flags_str.append("timing_disabled")
        if self._blocking:
            flags_str.append("blocking")
        if self._interprocess:
            flags_str.append("interprocess")
        if self._external:
            flags_str.append("external")
        if not flags_str:
            flags_str.append("default")
        return f"DeviceEvent(device={self.device}, flags={', '.join(flags_str)}, recorded={self.is_recorded})"