Repository URL to install this package:
|
Version:
4.0.7 ▾
|
import functools
import gc
import logging
import os
import typing as t
import jax
import psutil
from jax._src.lax import control_flow
# Module-level logger; the decorators in this module emit their reports at
# DEBUG level through it.
logger = logging.getLogger(__name__)
# NOTE(review): `stats` is not referenced anywhere in the visible code —
# confirm whether it is still needed before relying on it.
stats = [0]
# Type variable bound to "any callable", used so the decorators can declare
# that they return the same callable type they were given.
F = t.TypeVar("F", bound=t.Callable[..., t.Any])
def fmt_bytes(n: int) -> str:
    """Format a byte count as a human-readable string using 1024-based units.

    The value is scaled to the largest unit among B, KB, MB, GB and TB for
    which the integer part is below 1024 (TB is used for anything larger),
    and rendered with exactly two decimals, truncated rather than rounded.

    Args:
        n: Number of bytes. May be negative (e.g. a memory delta); the sign
            is kept in the output.

    Returns:
        A string such as ``"1.50 KB"`` or ``"-2.00 MB"``.
    """
    base = 1024
    units = ("B", "KB", "MB", "GB", "TB")

    sign = "-" if n < 0 else ""
    magnitude = abs(n)

    # Climb the unit ladder, but never past the last unit: very large values
    # stay expressed in TB instead of being mislabelled.
    scale = 1
    index = 0
    while index < len(units) - 1 and magnitude >= scale * base:
        scale *= base
        index += 1

    # Pure integer arithmetic keeps the result exact for arbitrarily large
    # ints; the remainder is converted to hundredths of the chosen unit.
    whole, remainder = divmod(magnitude, scale)
    hundredths = remainder * 100 // scale
    return f"{sign}{whole}.{hundredths:02d} {units[index]}"
def pid_info(fun: F) -> F:
    """Decorate ``fun`` to log how much resident memory its call released.

    The resident set size (RSS) of the current process is sampled right
    before and right after the call, and ``before - after`` is reported at
    DEBUG level through the module logger. The return value of ``fun`` is
    passed through unchanged. Nothing is logged when ``fun`` raises.

    Args:
        fun: Callable to instrument.

    Returns:
        A wrapper around ``fun`` carrying its metadata (``functools.wraps``),
        cast back to the type of ``fun``.
    """

    @functools.wraps(fun)
    def wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any:
        process = psutil.Process(os.getpid())
        # Only RSS is needed, and `memory_info()` already provides it.
        # `memory_full_info()` additionally walks the whole address space to
        # compute fields that are never read here, which is slower and may
        # require elevated privileges on some platforms.
        before = process.memory_info().rss
        res = fun(*args, **kwargs)
        after = process.memory_info().rss
        logger.debug(
            f"'{fun.__name__}' released {fmt_bytes(before - after)}"
            f" from {fmt_bytes(before)}"
        )
        return res

    return t.cast(F, wrapped)
@pid_info
def jaxjit_cleanup(*jitted_functions: jax._src.pjit.JitWrapped) -> None:
    """Release what the given jitted functions hold on to.

    For each function, ``clear_cache()`` is called when the object offers it,
    and the private ``_fun`` attribute (presumably the reference to the
    original Python callable — confirm against the JAX version in use) is
    deleted when present. Attributes that are missing are skipped, so passing
    the same function twice, or cleaning it again later, does not raise.

    Only state stored on the jitted objects is dropped here; callers still
    have to drop their own references to the functions.

    Args:
        *jitted_functions: Jitted callables to clean up.
    """
    for jit_fun in jitted_functions:
        if hasattr(jit_fun, "clear_cache"):
            jit_fun.clear_cache()
        # Guarded so that an object already stripped of `_fun` is tolerated.
        if hasattr(jit_fun, "_fun"):
            del jit_fun._fun  # type:ignore
@pid_info
def jax_clear_backend() -> None:
    """Clear JAX's backends and then its caches.

    ``clear_backends`` is taken from ``jax.extend.backend`` when that module
    provides it, because the top-level ``jax.clear_backends`` alias was
    deprecated and later removed. Older JAX releases without the new location
    fall back to the top-level name.
    """
    try:
        # Local import: the new location does not exist on older releases.
        from jax.extend.backend import clear_backends
    except ImportError:
        clear_backends = jax.clear_backends
    clear_backends()  # type:ignore[no-untyped-call]
    jax.clear_caches()  # type:ignore[no-untyped-call]
@pid_info
def gc_collect() -> None:
    """Trigger a garbage collection pass and discard the collector's count.

    The memory difference around the call is reported by the ``pid_info``
    decorator.
    """
    gc.collect()
    return None
def jax_cleanup(fun: F) -> F:
    """Reclaims as much memory as possible from JAX. Meant to be called at the
    end of the scope using JAX.

    The returned wrapper calls ``fun`` and then, whether ``fun`` returned or
    raised, logs the cache statistics of JAX's control-flow jaxpr cache at
    DEBUG level, clears the JAX backends and caches, runs the garbage
    collector, and logs the cache statistics once more. The result (or
    exception) of ``fun`` is propagated unchanged.

    Args:
        fun: Callable whose JAX footprint should be reclaimed after each call.

    Returns:
        A wrapper around ``fun`` carrying its metadata (``functools.wraps``),
        cast back to the type of ``fun``.
    """

    def cache_stats() -> t.Any:
        # Private JAX API: statistics of the memoised control-flow jaxprs.
        return control_flow._initial_style_jaxprs_with_common_consts.cache_info()  # noqa: E501

    @functools.wraps(fun)
    def wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any:
        try:
            return fun(*args, **kwargs)
        finally:
            # Runs on both normal return and exception, so a failing call
            # does not leave JAX memory behind.
            logger.debug(
                f"JAX stats after '{fun.__name__}' : "
                f"{cache_stats()}",
            )
            jax_clear_backend()
            gc_collect()
            # Report the statistics again so the effect of the cleanup is
            # visible; the original computed them here but dropped the value.
            logger.debug(
                f"JAX stats after cleanup of '{fun.__name__}' : "
                f"{cache_stats()}",
            )

    return t.cast(F, wrapped)