Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
Size: Mime:
import functools
import gc
import logging
import os
import typing as t

import jax
import psutil
from jax._src.lax import control_flow

# Generic callable type: lets the decorators below return the same type
# they were given (via ``t.cast``).
F = t.TypeVar("F", bound=t.Callable[..., t.Any])

logger: logging.Logger = logging.getLogger(__name__)
# NOTE(review): not read anywhere in this file as far as visible — confirm
# whether other modules use it.
stats: t.List[int] = [0]


def fmt_bytes(n: int, precision: int = 2) -> str:
    """Format a byte count as a human-readable string using binary (1024) units.

    Args:
        n: Number of bytes. Negative values keep their sign, so memory
            deltas are reported faithfully.
        precision: Digits after the decimal point (default 2).

    Returns:
        A string such as ``"1.50 KB"``. Values of 1024 TB or more are still
        expressed in TB (e.g. ``"2048.00 TB"``).
    """
    base = 1024
    units = ("B", "KB", "MB", "GB", "TB")
    sign = "-" if n < 0 else ""
    value = float(abs(n))
    for unit in units[:-1]:
        # Compare the *rounded* value so e.g. 1023.999 KB is shown as
        # "1.00 MB" instead of "1024.00 KB".
        if round(value, precision) < base:
            return f"{sign}{value:.{precision}f} {unit}"
        value /= base
    return f"{sign}{value:.{precision}f} {units[-1]}"


def pid_info(fun: F) -> F:
    """Decorate ``fun`` to log how much resident memory (RSS) a call released.

    The RSS of the current process is sampled immediately before and after
    each call. ``before - after`` and ``before`` are then logged at DEBUG
    level, both formatted with ``fmt_bytes``. When DEBUG logging is disabled
    for this module's logger, the call is forwarded without any sampling.

    Args:
        fun: The callable to wrap.

    Returns:
        A wrapper with the same signature that returns ``fun``'s result.
    """

    @functools.wraps(fun)
    def wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any:
        # The log line would be dropped anyway; skip the psutil overhead.
        if not logger.isEnabledFor(logging.DEBUG):
            return fun(*args, **kwargs)
        proc = psutil.Process(os.getpid())
        # memory_info() exposes the same ``rss`` as memory_full_info(), but is
        # cheaper and does not need the extra privileges that the additional
        # USS/PSS fields of memory_full_info() can require.
        before = proc.memory_info().rss
        res = fun(*args, **kwargs)
        after = proc.memory_info().rss
        logger.debug(
            "'%s' released %s from %s",
            fun.__name__,
            fmt_bytes(before - after),
            fmt_bytes(before),
        )
        return res

    return t.cast(F, wrapped)


@pid_info
def jaxjit_cleanup(*jitted_functions: jax._src.pjit.JitWrapped) -> None:
    """Clear the caches of the given ``jax.jit`` wrappers.

    For every argument exposing ``clear_cache``, the cache is cleared and the
    wrapper's ``_fun`` attribute is deleted. This is presumably done to drop
    the reference to the original Python callable and whatever it closes
    over; verify against current JAX internals. Arguments without
    ``clear_cache`` are skipped.

    Args:
        *jitted_functions: Wrappers returned by ``jax.jit``.
    """
    for jit_fun in jitted_functions:
        if not hasattr(jit_fun, "clear_cache"):
            continue
        jit_fun.clear_cache()
        # Guarded: the same wrapper passed twice (or one without ``_fun``)
        # would otherwise raise AttributeError and abort cleanup of the rest.
        if hasattr(jit_fun, "_fun"):
            del jit_fun._fun  # type:ignore


@pid_info
def jax_clear_backend() -> None:
    """Clear JAX backends (when that API exists) and JAX's compilation caches."""
    # ``jax.clear_backends`` is deprecated in recent JAX releases and may be
    # absent. Look it up defensively so the caches are still cleared.
    clear_backends = getattr(jax, "clear_backends", None)
    if clear_backends is not None:
        clear_backends()
    jax.clear_caches()  # type:ignore[no-untyped-call]


@pid_info
def gc_collect() -> None:
    """Force a full garbage-collection pass; ``pid_info`` logs the RSS change."""
    # Generation 2 is the oldest generation, which is exactly what a bare
    # ``gc.collect()`` collects. The returned object count is not needed.
    gc.collect(2)


def jax_cleanup(fun: F) -> F:
    """Decorate ``fun`` so JAX memory is reclaimed after every call.

    After ``fun`` returns *or raises*, the wrapper does the following:
    - calls ``jax_clear_backend`` and then ``gc_collect``;
    - logs the ``cache_info()`` of JAX's private
      ``_initial_style_jaxprs_with_common_consts`` helper at DEBUG level,
      both before and after that cleanup.

    Meant to be applied to the function that delimits the scope using JAX.

    Args:
        fun: The callable to wrap.

    Returns:
        A wrapper with the same signature that returns ``fun``'s result.
    """

    def _jaxpr_cache_info() -> t.Any:
        # Private JAX internal that may move between versions: return None
        # instead of raising after ``fun`` has already run successfully.
        cached = getattr(
            control_flow, "_initial_style_jaxprs_with_common_consts", None
        )
        cache_info = getattr(cached, "cache_info", None)
        return cache_info() if cache_info is not None else None

    @functools.wraps(fun)
    def wrapped(*args: t.Any, **kwargs: t.Any) -> t.Any:
        try:
            return fun(*args, **kwargs)
        finally:
            # Reclaim memory even when ``fun`` raised.
            logger.debug(
                "JAX stats after '%s' : %s", fun.__name__, _jaxpr_cache_info()
            )
            jax_clear_backend()
            gc_collect()
            logger.debug(
                "JAX stats after cleanup of '%s' : %s",
                fun.__name__,
                _jaxpr_cache_info(),
            )

    return t.cast(F, wrapped)