Skip to content

Commit 311249a

Browse files
BlaziusMaximusOrbax Authors
authored andcommitted
Add CNS tests for partial saving. Also modify async file writes to unlink before writing.
PiperOrigin-RevId: 874287343
1 parent eb178fb commit 311249a

1 file changed

Lines changed: 18 additions & 2 deletions

File tree

checkpoint/orbax/checkpoint/_src/path/async_path.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,31 @@ def _mkdir_sync(**thread_kwargs):
5050

5151

5252
async def write_bytes(path: epath.Path, data: Any) -> int:
53-
return await asyncio.to_thread(path.write_bytes, data)
53+
54+
def _write():
55+
try:
56+
path.unlink()
57+
except OSError:
58+
pass
59+
return path.write_bytes(data)
60+
61+
return await asyncio.to_thread(_write)
5462

5563

5664
async def read_bytes(path: epath.Path) -> bytes:
5765
return await asyncio.to_thread(path.read_bytes)
5866

5967

6068
async def write_text(path: epath.Path, text: str) -> int:
61-
return await asyncio.to_thread(path.write_text, text)
69+
70+
def _write():
71+
try:
72+
path.unlink()
73+
except OSError:
74+
pass
75+
return path.write_text(text)
76+
77+
return await asyncio.to_thread(_write)
6278

6379

6480
async def read_text(path: epath.Path) -> str:

0 commit comments

Comments
 (0)