# Copyright (c) Meta Platforms, Inc. and affiliates
import shutil
import tempfile
from functools import wraps
from typing import Any, Callable, Optional

import torch.distributed as dist


def with_temp_dir(
    func: Optional[Callable] = None,
) -> Optional[Callable]:
    """
    Decorator that creates a temporary directory on rank 0, broadcasts its
    path to all ranks, and exposes it as ``self.temp_dir`` while the wrapped
    distributed checkpoint test runs. Rank 0 removes the directory afterwards.
    Assumes the default process group is already initialized.
    """
    assert func is not None
    @wraps(func)
    def wrapper(self, *args: Any, **kwargs: Any) -> None:
        # Only rank 0 creates the temp directory; every other rank starts
        # with an empty placeholder and receives the real path below.
        if dist.get_rank() == 0:
            temp_dir = tempfile.mkdtemp()
            print(f"Using temp directory: {temp_dir}")
        else:
            temp_dir = ""
        object_list = [temp_dir]

        # Broadcast the rank-0 path so all ranks share one directory.
        dist.broadcast_object_list(object_list)
        self.temp_dir = object_list[0]
        try:
            # Run the wrapped test with the shared directory in place,
            # forwarding any extra arguments.
            func(self, *args, **kwargs)
        finally:
            # Make sure every rank is done with the directory before
            # rank 0 deletes it out from under the others.
            dist.barrier()
            if dist.get_rank() == 0:
                shutil.rmtree(self.temp_dir, ignore_errors=True)

    return wrapper
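

# Usage sketch (illustrative, not part of the original module): assumes a
# test class whose process group is already initialized, e.g. a subclass of
# torch.testing._internal.common_distributed.MultiProcessTestCase, and uses
# torch.distributed.checkpoint for the save call. The class and test names
# below are hypothetical.
#
#     import torch
#     import torch.distributed.checkpoint as dcp
#     from torch.testing._internal.common_distributed import MultiProcessTestCase
#
#     class CheckpointTest(MultiProcessTestCase):
#         @with_temp_dir
#         def test_save(self) -> None:
#             # All ranks see the same path broadcast from rank 0.
#             state_dict = {"step": torch.tensor(0)}
#             dcp.save(state_dict, checkpoint_id=self.temp_dir)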