Skip to content

Commit

Permalink
Optimizations + logging
Browse files Browse the repository at this point in the history
  • Loading branch information
gspschmid committed Jul 5, 2024
1 parent 3767297 commit d78c9e8
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
38 changes: 36 additions & 2 deletions checkpoint/orbax/checkpoint/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,40 @@ async def release_bytes(self, requested_bytes):
self._cv.notify_all()


def report_save_stats(primary_host, saving_replica_id, shard):
  """Print one `SAVE-STATS` line describing this process's handling of `shard`.

  The line holds six comma-separated fields: the current process index,
  whether that index equals `primary_host`, the shard's device id, the shard's
  replica id, whether the shard's replica id equals `saving_replica_id`, and
  the byte count attributed to the shard.

  Args:
    primary_host: Process index that is compared against the current process.
    saving_replica_id: Replica id whose shards are counted as written.
    shard: Array shard exposing `data`, `device` and `replica_id`.
  """
  is_saving_replica = shard.replica_id == saving_replica_id
  # Only the saving replica contributes bytes (element size * element count);
  # `shard.data` is not touched for any other replica.
  bytes_written = (
      shard.data.itemsize * shard.data.size if is_saving_replica else 0
  )

  process_index = multihost.process_index()
  is_primary_host = process_index == primary_host
  print(f'SAVE-STATS {process_index},{is_primary_host},{shard.device.id},{shard.replica_id},{is_saving_replica},{bytes_written}')


async def transfer_shard_to_host(shard) -> np.ndarray:
  """Return the data of `shard` as a NumPy array on the host.

  When the shard's device lists a "pinned_host" memory kind and JAX memories
  are enabled, the data is first placed in that memory via `jax.device_put`.
  The coroutine then yields once to the event loop before materializing the
  NumPy array.

  Args:
    shard: Array shard exposing `data` and `device`.

  Returns:
    The shard's data as an `np.ndarray`. No copy is made when the data can be
    viewed as an ndarray directly.
  """
  data = shard.data
  has_pinned_host = any(
      m.kind == "pinned_host" for m in shard.device.addressable_memories()
  )
  if has_pinned_host and jax.config.jax_enable_memories:
    # If available, transfer to pinned host memory.
    sharding = jax.sharding.SingleDeviceSharding(
        shard.device, memory_kind="pinned_host"
    )
    data = jax.device_put(data, sharding)
  # Allow other transfers to be scheduled simultaneously.
  await asyncio.sleep(0)
  # `np.asarray` copies only when required on both NumPy 1.x and 2.x, whereas
  # `np.array(..., copy=False)` raises under NumPy 2.x if a copy is unavoidable.
  return np.asarray(data)


async def async_serialize(
arr_inp,
tensorstore_spec,
commit_future=None,
context=TS_CONTEXT,
primary_host: Optional[int] = 0,
primary_host: int | None = 0,
replica_id: int = 0,
transaction: Optional[ts.Transaction] = None,
):
"""Serialize an array using TensorStore.
Expand All @@ -199,6 +226,8 @@ async def async_serialize(
unless you are sure you know what you are doing.
replica_id: Allows overriding the shard replica id that will be saved.
DO NOT USE unless you are sure you know what you are doing.
transaction: TensorStore transaction to use for opening and writing the
array. If not specified, a non-transactional write will be used.
"""
if (
isinstance(arr_inp, jax.Array)
Expand Down Expand Up @@ -228,6 +257,7 @@ async def async_serialize(
create=True,
open=True,
context=context,
transaction=transaction,
)
# Asynchronous case.
if commit_future is not None:
Expand All @@ -247,11 +277,15 @@ async def async_serialize(
open=True,
assume_metadata=True,
context=context,
transaction=transaction,
)

async def _write_array(shard):
report_save_stats(primary_host, replica_id, shard)
if shard.replica_id == replica_id:
write_future = t[shard.index].write(shard.data)
data = await transfer_shard_to_host(shard)
# write_future = t[shard.index].write(data)
write_future = t[shard.index].write(data, can_reference_source_data_indefinitely=True)
if commit_future is not None:
assert isinstance(commit_future, list)
commit_future.append(write_future.commit)
Expand Down
1 change: 1 addition & 0 deletions checkpoint/orbax/checkpoint/type_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1325,6 +1325,7 @@ async def serialize(
context=ts_context,
primary_host=self._primary_host,
replica_id=replica_id,
transaction=txn,
)
]
else:
Expand Down

0 comments on commit d78c9e8

Please sign in to comment.