Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
jax / jax / _src / xla_bridge.py
Size: Mime:
# Copyright 2018 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Interface and utility functions to XLA.

This module wraps the XLA client(s) and builders to standardize their interfaces
and provide some automatic type mapping logic for converting between Numpy and
XLA. There are also a handful of related casting utilities.
"""

from collections.abc import Mapping
import dataclasses
from functools import partial, lru_cache
import importlib
import json
import logging
import os
import platform as py_platform
import pkgutil
import sys
import threading
from typing import Any, Callable, Optional, Union
import warnings

import numpy as np

from jax._src import lib
from jax._src import distributed
from jax._src import config as jax_config
from jax._src.config import bool_env, config, int_env
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src import traceback_util
from jax._src import util

logger = logging.getLogger(__name__)

jax_plugins: Optional[Any]
try:
  import jax_plugins  # type: ignore
except ModuleNotFoundError:
  jax_plugins = None
except ImportError as e:
  logger.error("Failed to import jax_plugins: %s", e)
  jax_plugins = None

traceback_util.register_exclusion(__file__)


XlaBackend = xla_client.Client


# TODO(phawkins): Remove jax_xla_backend.
_XLA_BACKEND = jax_config.DEFINE_string(
    'jax_xla_backend', '',
    'Deprecated, please use --jax_platforms instead.')
BACKEND_TARGET = jax_config.DEFINE_string(
    'jax_backend_target',
    os.getenv('JAX_BACKEND_TARGET', '').lower(),
    'Either "local" or "rpc:address" to connect to a remote service target.')
# TODO(skye): warn when this is used once we test out --jax_platforms a bit
_PLATFORM_NAME = jax_config.DEFINE_string(
    'jax_platform_name',
    os.getenv('JAX_PLATFORM_NAME', '').lower(),
    'Deprecated, please use --jax_platforms instead.')
_DISABLE_MOST_OPTIMIZATIONS = jax_config.DEFINE_bool(
    'jax_disable_most_optimizations',
    bool_env('JAX_DISABLE_MOST_OPTIMIZATIONS', False),
    'Try not to do much optimization work. This can be useful if the cost of '
    'optimization is greater than that of running a less-optimized program.')
_XLA_PROFILE_VERSION = jax_config.DEFINE_integer(
    'jax_xla_profile_version', int_env('JAX_XLA_PROFILE_VERSION', 0),
    'Optional profile version for XLA compilation. '
    'This is meaningful only when XLA is configured to '
    'support the remote compilation profile feature.')
CUDA_VISIBLE_DEVICES = jax_config.DEFINE_string(
    'jax_cuda_visible_devices', 'all',
    'Restricts the set of CUDA devices that JAX will use. Either "all", or a '
    'comma-separate list of integer device IDs.')
_ROCM_VISIBLE_DEVICES = jax_config.DEFINE_string(
    'jax_rocm_visible_devices', 'all',
    'Restricts the set of ROCM devices that JAX will use. Either "all", or a '
    'comma-separate list of integer device IDs.')


def get_compile_options(
    num_replicas: int,
    num_partitions: int,
    device_assignment=None,
    use_spmd_partitioning: bool = True,
    use_auto_spmd_partitioning: bool = False,
    auto_spmd_partitioning_mesh_shape=[],
    auto_spmd_partitioning_mesh_ids=[],
    env_options_overrides: Optional[dict[str, str]] = None,
    fdo_profile: Optional[bytes] = None,
) -> xla_client.CompileOptions:
  """Returns the compile options to use, as derived from flag values.

  Args:
    num_replicas: Number of replicas for which to compile.
    num_partitions: Number of partitions for which to compile.
    device_assignment: Optional ndarray of jax devices indicating the assignment
      of logical replicas to physical devices (default inherited from
      xla_client.CompileOptions). Must be consistent with `num_replicas` and
      `num_partitions`.
    use_spmd_partitioning: boolean indicating whether to enable SPMD or MPMD
      partitioning in XLA.
    use_auto_spmd_partitioning: boolean indicating whether to automatically
      generate XLA shardings for SPMD partitioner.
    auto_spmd_partitioning_mesh_shape: device mesh shape used to create
      auto_spmd_partitioning search space.
    auto_spmd_partitioning_mesh_ids: device ids used to create
      auto_spmd_partitioning search space.
    env_options_overrides: dict of additional options parsed by the compiler
    fdo_profile: Optional profile for feedback-directed optimization passed to
    XLA.
  """
  compile_options = xla_client.CompileOptions()
  compile_options.num_replicas = num_replicas
  compile_options.num_partitions = num_partitions
  build_options = compile_options.executable_build_options
  build_options.use_spmd_partitioning = use_spmd_partitioning
  build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning
  if xla_extension_version > 165 and fdo_profile is not None:
    build_options.fdo_profile = fdo_profile
  if use_auto_spmd_partitioning:
    build_options.auto_spmd_partitioning_mesh_shape = auto_spmd_partitioning_mesh_shape
    build_options.auto_spmd_partitioning_mesh_ids = auto_spmd_partitioning_mesh_ids
  if device_assignment is not None:
    logger.debug(
        'get_compile_options: num_replicas=%s num_partitions=%s device_assignment=%s',
        num_replicas, num_partitions, device_assignment)
    device_assignment = np.array(device_assignment)

    # Allow 1D device assignment if num_partitions is 1.
    if (device_assignment.ndim == 1) and (num_partitions == 1):
      device_assignment = device_assignment[:, None]

    if num_replicas != device_assignment.shape[0]:
      msg = 'device_assignment does not match num_replicas: {} vs {}.'
      raise ValueError(msg.format(device_assignment, num_replicas))

    if num_partitions != device_assignment.shape[1]:
      msg = 'device_assignment does not match num_partitions: {} vs {}.'
      raise ValueError(msg.format(device_assignment, num_partitions))

    if device_assignment.dtype == object:
      device_assignment = np.vectorize(lambda d: d.id, otypes=[int])(
          device_assignment)
    device_assignment = xla_client.DeviceAssignment.create(device_assignment)
    assert device_assignment.replica_count() == num_replicas
    assert device_assignment.computation_count() == num_partitions
    compile_options.device_assignment = device_assignment

  if env_options_overrides is not None:
    compile_options.env_option_overrides = list(env_options_overrides.items())

  debug_options = compile_options.executable_build_options.debug_options
  if lib.cuda_path is not None:
    debug_options.xla_gpu_cuda_data_dir = lib.cuda_path

  if _DISABLE_MOST_OPTIMIZATIONS.value:

    debug_options.xla_backend_optimization_level = 0
    debug_options.xla_llvm_disable_expensive_passes = True
    debug_options.xla_test_all_input_layouts = False

  compile_options.profile_version = _XLA_PROFILE_VERSION.value
  return compile_options


# Backends

def tpu_client_timer_callback(timer_secs: float) -> Optional[xla_client.Client]:
  def _log_warning():
    warnings.warn(
      f'TPU backend initialization is taking more than {timer_secs} seconds. '
      'Did you run your code on all TPU hosts? '
      'See https://jax.readthedocs.io/en/latest/multi_process.html '
      'for more information.')

  # Will log a warning after `timer_secs`.
  t = threading.Timer(timer_secs, _log_warning)
  t.start()

  try:
    client = xla_client.make_tpu_client()
  finally:
    t.cancel()

  return client


# Backends
#
# We have no particular opinion about how "backends" relate to "devices". For
# example, there could be multiple backends that provide the same kind of
# device.

BackendFactory = Callable[[], Optional[xla_client.Client]]

@dataclasses.dataclass
class BackendRegistration:
  factory: BackendFactory

  # Priority of this backend when choosing a default backend. Higher = more
  # preferred.
  priority: int

  # If this backend fails to initialize, should we log a user-visible warning?
  # For plugins (e.g., TPU) we usually want a visible failure, because why
  # install a plugin if you don't intend it to be used?
  fail_quietly: bool = False

  # Is this plugin experimental? If a plugin is deemed experimental, we issue
  # a warning when it is initialized. This is mostly to set user expectations
  # correctly: we don't want users to think that JAX is buggy because of a
  # a buggy plugin.
  experimental: bool = False

_backend_factories: dict[str, BackendRegistration] = {}
_default_backend: Optional[xla_client.Client] = None
_backends : dict[str, xla_client.Client] = {}
_backends_errors : dict[str, str] = {}
_backend_lock = threading.Lock()

# The set of known non-experimental plugins.
#
# If a plugin passes the JAX test suite, it can be added to the allowlist below.
# Send a PR if you would like to be added.
#
# It is fine for a plugin not to implement every feature that JAX uses, provided
# that a reasonable feature set is implemented and the plugin fails gracefully
# for unimplemented features. Wrong outputs are not acceptable.
_nonexperimental_plugins: set[str] = set()

def register_backend_factory(name: str, factory: BackendFactory, *,
                             priority: int = 0,
                             fail_quietly: bool = True,
                             experimental: bool = False) -> None:
  with _backend_lock:
    if name in _backends:
      raise RuntimeError(f"Backend {name} already initialized")
  _backend_factories[name] = BackendRegistration(
    factory, priority, fail_quietly, experimental)


register_backend_factory('cpu',
                         partial(xla_client.make_cpu_client, use_tfrt=True),
                         priority=0,
                         fail_quietly=False)


def make_gpu_client(
    *, platform_name: str, visible_devices_flag: jax_config.FlagHolder[str]
) -> xla_client.Client:
  visible_devices = visible_devices_flag.value
  allowed_devices = None
  if visible_devices != "all":
    allowed_devices = {int(x) for x in visible_devices.split(",")}

  if xla_extension_version < 160:
    return xla_client.make_gpu_client(
        distributed_client=distributed.global_state.client,
        node_id=distributed.global_state.process_id,
        platform_name=platform_name,
        allowed_devices=allowed_devices,
    )
  else:
    # Remove `type: ignore` when the min jaxlib version (xla_extension_version)
    # >= 160.
    return xla_client.make_gpu_client(
        distributed_client=distributed.global_state.client,
        node_id=distributed.global_state.process_id,
        num_nodes=distributed.global_state.num_processes,
        platform_name=platform_name,
        allowed_devices=allowed_devices,
    )  # type: ignore


if hasattr(xla_client, "make_gpu_client"):
  register_backend_factory(
      "cuda",
      partial(
          make_gpu_client,
          platform_name="cuda",
          visible_devices_flag=CUDA_VISIBLE_DEVICES,
      ),
      priority=200,
      fail_quietly=True,
  )
  register_backend_factory(
      "rocm",
      partial(
          make_gpu_client,
          platform_name="rocm",
          visible_devices_flag=_ROCM_VISIBLE_DEVICES,
      ),
      priority=200,
      fail_quietly=True,
  )


if hasattr(xla_client, "make_tpu_client"):
  # TODO(phawkins,skyewm): switch TPU plugin to use the PJRT plugin mechanism,
  # and then fail loudly on initialization failure.
  register_backend_factory(
    'tpu', partial(tpu_client_timer_callback, timer_secs=60.0), priority=300,
    fail_quietly=True)


def _get_pjrt_plugin_names_and_library_paths(
    plugins_from_env: str,
) -> dict[str, str]:
  """Gets the names and library paths of PJRT plugins to load from env var.

  Args:
    plugins_from_env: plugin name and pathes from env var. It is in the format
      of 'name1:path1,name2:path2' ('name1;path1,name2;path2' for windows).

  Returns:
    A dict of {plugin_name: library path} for the PJRT plugins to load.
  """
  if not plugins_from_env:
    return {}

  pjrt_plugins = {}
  for plugin in plugins_from_env.split(','):
    try:
      name, library_path = plugin.split(os.path.pathsep)
      pjrt_plugins[name] = library_path
    except ValueError:
      logger.warning(
          'invalid value %s in env var PJRT_NAMES_AND_LIBRARY_PATHS %s',
          plugin,
          plugins_from_env,
      )
  return pjrt_plugins


def _get_pjrt_plugin_config(
    json_path: str,
) -> tuple[str, Optional[Mapping[str, Union[str, int, list[int], float]]]]:
  """Gets PJRT plugin configuration from a json file.

  The json file needs to have a "library_path" field for the plugin library
  path. It can have an optional "create_option" field for the options used when
  creating a PJRT plugin client. The value of "create_option" is key-value
  pairs. Please see xla_client._NameValueMapping for the supported types of
  values.
  """
  with open(json_path) as f:
    config = json.load(f)
  if 'library_path' not in config.keys():
    raise ValueError(
        'PJRT plugin config file should contain "library_path" field.'
    )
  return (config['library_path'], config.get('create_options'))

def discover_pjrt_plugins() -> None:
  """Discovers plugins in the namespace package `jax_plugins` and import them.

  There are two methods used to discover plugin modules. They are intended
  to be used together by implementors in order to cover all packaging and
  development cases:

  1. Define a globally unique module under the `jax_plugins` namespace
     package (i.e. just create a `jax_plugins` directory and define your
     module below it).
  2. If building a package via pyproject.toml or setup.py, advertise your
     plugin module name by including an entry-point under the `jax_plugins`
     group which points to your full module name.

  During Jax startup, Jax will load each module discovered in such a way and
  call its `initialize()` function. It is expected that this function should
  register its concrete plugin name/implementations via call(s) to
  `jax._src.xla_bridge.register_plugin(name, priority=, library_paty=,
  options=)`. Since `initialize()` functions are called for all installed
  plugins, they should avoid doing expensive, non-registration related work.

  TODO: We should provide a variant of `register_plugin` which allows the
  library_path and options to be resolved via a callback. This would enable
  light-weight plugin registration in cases where options need to be derived
  from heavy-weight system initialization.
  """
  plugin_modules = set()
  # Scan installed modules under |jax_plugins|. Note that not all packaging
  # scenarios are amenable to such scanning, so we also use the entry-point
  # method to seed the list.
  if jax_plugins:
    for _, name, _ in pkgutil.iter_modules(
        jax_plugins.__path__, jax_plugins.__name__ + '.'
    ):
      logger.debug("Discovered path based JAX plugin: %s", name)
      plugin_modules.add(name)
  else:
    logger.debug("No jax_plugins namespace packages available")

  # Augment with advertised entrypoints.
  if sys.version_info < (3, 10):
    # Use the backport library because it provides a forward-compatible
    # implementation.
    from importlib_metadata import entry_points
  else:
    from importlib.metadata import entry_points

  for entry_point in entry_points(group="jax_plugins"):
    logger.debug("Discovered entry-point based JAX plugin: %s",
                 entry_point.value)
    plugin_modules.add(entry_point.value)

  # Now load and initialize them all.
  for plugin_module_name in plugin_modules:
    logger.debug("Loading plugin module %s", plugin_module_name)
    plugin_module = None
    try:
      plugin_module = importlib.import_module(plugin_module_name)
    except ModuleNotFoundError:
      logger.warning("Jax plugin configuration error: Plugin module %s "
                     "does not exist", plugin_module_name)
    except ImportError:
      logger.exception("Jax plugin configuration error: Plugin module %s "
                       "could not be loaded")

    if plugin_module:
      try:
        plugin_module.initialize()
      except:
        logger.exception("Jax plugin configuration error: Exception when "
                         "calling %s.initialize()", plugin_module_name)


# TODO(b/261345120): decide on a public name and expose a public method which is
# an alias of this method.
def register_plugin(
    plugin_name: str,
    *,
    priority: int = 400,
    library_path: Optional[str] = None,
    options: Optional[Mapping[str, Union[str, int, list[int], float]]] = None,
) -> None:
  """Registers a backend factory for the PJRT plugin.

  Args:
    plugin_name: the name of the plugin.
    priority: the priority this plugin should be registered in jax backends.
      Default to be 400.
    library_path: Optional. The full path to the .so file of the plugin.
      Required when the plugin is dynamically linked.
    options: Optional. It is used when creating a PJRT plugin client.
  """
  def factory():
    # Plugin may already be statically linked in some configurations.
    if not xla_client.pjrt_plugin_loaded(plugin_name):
      if library_path is None:
        raise ValueError(
            'The library path is None when trying to dynamically load the'
            ' plugin.'
        )
      xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path)

    if xla_extension_version < 165:
      return xla_client.make_c_api_client(plugin_name, options)
    else:
      if distributed.global_state.client is None:
        return xla_client.make_c_api_client(plugin_name, options, None)
      distribute_options = {
          'node_id': distributed.global_state.process_id,
          'num_nodes': distributed.global_state.num_processes,
      }
      if options is not None:
        distribute_options.update(options)
      return xla_client.make_c_api_client(
          plugin_name, distribute_options, distributed.global_state.client
      )


  logger.debug(
      'registering PJRT plugin %s from %s', plugin_name, library_path
  )
  experimental = plugin_name not in _nonexperimental_plugins
  register_backend_factory(plugin_name, factory, priority=priority,
                           fail_quietly=False, experimental=experimental)


def register_pjrt_plugin_factories_from_env() -> None:
  """Registers backend factories for PJRT plugins.

  A backend factory will be registered for every PJRT plugin in the input
  string, in the format of 'name1:path1,name2:path2' ('name1;path1,name2;path2'
  for windows). The path can be a path to the plugin library or a path to the
  plugin configuration json file. The json file needs to have a "library_path"
  field for the plugin library path. It can have an optional "create_option"
  field for the options used when creating a PJRT plugin client. The value of
  "create_option" is key-value pairs. Please see xla_client._NameValueMapping
  for the supported types of values.

  TPU PJRT plugin will be loaded and registered separately in make_tpu_client.
  """
  pjrt_plugins = _get_pjrt_plugin_names_and_library_paths(
      os.getenv('PJRT_NAMES_AND_LIBRARY_PATHS', '')
  )
  for plugin_name, path in pjrt_plugins.items():
    if path.endswith('.json'):
      library_path, options = _get_pjrt_plugin_config(path)
    else:
      library_path = path
      options = None
    logger.debug(
        'registering PJRT plugin %s from %s', plugin_name, library_path
    )
    register_plugin(plugin_name, library_path=library_path, options=options)


# Plugins in the namespace package `jax_plugins` will be imported.
discover_pjrt_plugins()
# Registers plugins names and paths set in env var PJRT_NAMES_AND_LIBRARY_PATHS,
# in the format of 'name1:path1,name2:path2' ('name1;path1,name2;path2' for
# windows).
register_pjrt_plugin_factories_from_env()

_platform_aliases = {
  "cuda": "gpu",
  "rocm": "gpu",
}

_alias_to_platforms: dict[str, list[str]] = {}
for _platform, _alias in _platform_aliases.items():
  _alias_to_platforms.setdefault(_alias, []).append(_platform)


def is_known_platform(platform: str) -> bool:
  # A platform is valid if there is a registered factory for it. It does not
  # matter if we were unable to initialize that platform; we only care that
  # we've heard of it and it isn't, e.g., a typo.
  return (platform in _backend_factories.keys() or
          platform in _platform_aliases.keys())


def canonicalize_platform(platform: str) -> str:
  """Replaces platform aliases with their concrete equivalent.

  In particular, replaces "gpu" with either "cuda" or "rocm", depending on which
  hardware is actually present. We want to distinguish "cuda" and "rocm" for
  purposes such as MLIR lowering rules, but in many cases we don't want to
  force users to care.
  """
  platforms = _alias_to_platforms.get(platform, None)
  if platforms is None:
    return platform

  b = backends()
  for p in platforms:
    if p in b.keys():
      return p
  raise RuntimeError(f"Unknown backend: '{platform}' requested, but no "
                     f"platforms that are instances of {platform} are present. "
                     "Platforms present are: " + ",".join(b.keys()))


def expand_platform_alias(platform: str) -> list[str]:
  """Expands, e.g., "gpu" to ["cuda", "rocm"].

  This is used for convenience reasons: we expect cuda and rocm to act similarly
  in many respects since they share most of the same code.
  """
  return _alias_to_platforms.get(platform, [platform])

def is_gpu(platform):
  return platform in ("cuda", "rocm")

def backends() -> dict[str, xla_client.Client]:
  global _backends
  global _backends_errors
  global _default_backend

  with _backend_lock:
    if _backends:
      return _backends
    if config.jax_platforms:
      jax_platforms = config.jax_platforms.split(",")
      platforms = []
      # Allow platform aliases in the list of platforms.
      for platform in jax_platforms:
        platforms.extend(expand_platform_alias(platform))
      priorities = range(len(platforms), 0, -1)
      # If the user specified a list of platforms explicitly, always fail
      # loudly.
      fail_quietly_list = [False] * len(platforms)
      platform_registrations = list(
        zip(platforms, priorities, fail_quietly_list))
    else:
      platform_registrations = list(
          (platform, registration.priority, registration.fail_quietly)
          for platform, registration
          in _backend_factories.items()
      )
    default_priority = -1000
    for platform, priority, fail_quietly in platform_registrations:
      try:
        backend = _init_backend(platform)
        _backends[platform] = backend

        if priority > default_priority:
          _default_backend = backend
          default_priority = priority
      except Exception as err:
        err_msg = f"Unable to initialize backend '{platform}': {err}"
        if fail_quietly:
          _backends_errors[platform] = str(err)
          logger.info(err_msg)
        else:
          if config.jax_platforms:
            err_msg += " (set JAX_PLATFORMS='' to automatically choose an available backend)"
          else:
            err_msg += " (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)"
          raise RuntimeError(err_msg)

    assert _default_backend is not None
    # We don't warn about falling back to CPU on Mac OS, because we don't
    # support anything else there at the moment and warning would be pointless.
    if (py_platform.system() != "Darwin" and
        _default_backend.platform == "cpu" and
        _PLATFORM_NAME.value != 'cpu'):
      logger.warning('No GPU/TPU found, falling back to CPU. '
                      '(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)')
    return _backends


def _clear_backends() -> None:
  global _backends
  global _backends_errors
  global _default_backend

  logger.info("Clearing JAX backend caches.")
  with _backend_lock:
    _backends = {}
    _backends_errors = {}
    _default_backend = None

  get_backend.cache_clear()


def _init_backend(platform: str) -> xla_client.Client:
  registration = _backend_factories.get(platform, None)
  if registration is None:
    raise RuntimeError(
        f"Backend '{platform}' is not in the list of known backends: "
        f"{list(_backend_factories.keys())}.")

  if registration.experimental:
    logger.warning(f"Platform '{platform}' is experimental and not all JAX "
                   "functionality may be correctly supported!")
  logger.debug("Initializing backend '%s'", platform)
  backend = registration.factory()
  # TODO(skye): consider raising more descriptive errors directly from backend
  # factories instead of returning None.
  if backend is None:
    raise RuntimeError(f"Could not initialize backend '{platform}'")
  if backend.device_count() == 0:
    raise RuntimeError(f"Backend '{platform}' provides no devices.")
  util.distributed_debug_log(("Initialized backend", backend.platform),
                             ("process_index", backend.process_index()),
                             ("device_count", backend.device_count()),
                             ("local_devices", backend.local_devices()))
  logger.debug("Backend '%s' initialized", platform)
  return backend


def _get_backend_uncached(
    platform: Union[None, str, xla_client.Client] = None
) -> xla_client.Client:
  # TODO(mattjj,skyewm): remove this input polymorphism after we clean up how
  # 'backend' values are handled
  if platform is not None and not isinstance(platform, str):
    return platform

  platform = (platform or _XLA_BACKEND.value or _PLATFORM_NAME.value or None)

  bs = backends()
  if platform is not None:
    platform = canonicalize_platform(platform)
    backend = bs.get(platform, None)
    if backend is None:
      if platform in _backends_errors:
        raise RuntimeError(f"Backend '{platform}' failed to initialize: "
                           f"{_backends_errors[platform]}")
      raise RuntimeError(f"Unknown backend {platform}")
    return backend
  else:
    assert _default_backend is not None
    return _default_backend


@lru_cache(maxsize=None)  # don't use util.memoize because there is no X64 dependence.
def get_backend(
    platform: Union[None, str, xla_client.Client] = None
) -> xla_client.Client:
  return _get_backend_uncached(platform)


def get_device_backend(
    device: Optional[xla_client.Device] = None,
) -> xla_client.Client:
  """Returns the Backend associated with `device`, or the default Backend."""
  if device is not None:
    return device.client
  return get_backend()


def device_count(
    backend: Optional[Union[str, xla_client.Client]] = None
) -> int:
  """Returns the total number of devices.

  On most platforms, this is the same as :py:func:`jax.local_device_count`.
  However, on multi-process platforms where different devices are associated
  with different processes, this will return the total number of devices across
  all processes.

  Args:
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.

  Returns:
    Number of devices.

  """
  return int(get_backend(backend).device_count())


def local_device_count(
    backend: Optional[Union[str, xla_client.Client]] = None
) -> int:
  """Returns the number of devices addressable by this process."""
  return int(get_backend(backend).local_device_count())


def devices(
    backend: Optional[Union[str, xla_client.Client]] = None
) -> list[xla_client.Device]:
  """Returns a list of all devices for a given backend.

  .. currentmodule:: jaxlib.xla_extension

  Each device is represented by a subclass of :class:`Device` (e.g.
  :class:`CpuDevice`, :class:`GpuDevice`). The length of the returned list is
  equal to ``device_count(backend)``. Local devices can be identified by
  comparing :attr:`Device.process_index` to the value returned by
  :py:func:`jax.process_index`.

  If ``backend`` is ``None``, returns all the devices from the default backend.
  The default backend is generally ``'gpu'`` or ``'tpu'`` if available,
  otherwise ``'cpu'``.

  Args:
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.

  Returns:
    List of Device subclasses.
  """
  return get_backend(backend).devices()


def default_backend() -> str:
  """Returns the platform name of the default XLA backend."""
  return get_backend(None).platform

@lru_cache
def local_devices(process_index: Optional[int] = None,
                  backend: Optional[Union[str, xla_client.Client]] = None,
                  host_id: Optional[int] = None) -> list[xla_client.Device]:
  """Like :py:func:`jax.devices`, but only returns devices local to a given process.

  If ``process_index`` is ``None``, returns devices local to this process.

  Args:
    process_index: the integer index of the process. Process indices can be
      retrieved via ``len(jax.process_count())``.
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.

  Returns:
    List of Device subclasses.
  """
  if host_id is not None:
    warnings.warn(
        "The argument to jax.local_devices has been renamed from `host_id` to "
        "`process_index`. This alias will eventually be removed; please update "
        "your code.")
    process_index = host_id
  if process_index is None:
    process_index = get_backend(backend).process_index()
  if not (0 <= process_index < process_count()):
    raise ValueError(f"Unknown process_index {process_index}")
  return [d for d in devices(backend) if d.process_index == process_index]


def process_index(
    backend: Optional[Union[str, xla_client.Client]] = None
) -> int:
  """Returns the integer process index of this process.

  On most platforms, this will always be 0. This will vary on multi-process
  platforms though.

  Args:
    backend: This is an experimental feature and the API is likely to change.
      Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or
      ``'tpu'``.

  Returns:
    Integer process index.
  """
  return get_backend(backend).process_index()


# TODO: remove this sometime after jax 0.2.13 is released
def host_id(backend: Optional[Union[str, xla_client.Client]] = None) -> int:
  warnings.warn(
      "jax.host_id has been renamed to jax.process_index. This alias "
      "will eventually be removed; please update your code.")
  return process_index(backend)


@lru_cache
def process_count(
    backend: Optional[Union[str, xla_client.Client]] = None
) -> int:
  """Returns the number of JAX processes associated with the backend."""
  return max(d.process_index for d in devices(backend)) + 1


# TODO: remove this sometime after jax 0.2.13 is released
def host_count(backend: Optional[Union[str, xla_client.Client]] = None) -> int:
  warnings.warn(
      "jax.host_count has been renamed to jax.process_count. This alias "
      "will eventually be removed; please update your code.")
  return process_count(backend)


# TODO: remove this sometime after jax 0.2.13 is released
def host_ids(
    backend: Optional[Union[str, xla_client.Client]] = None
) -> list[int]:
  warnings.warn(
      "jax.host_ids has been deprecated; please use range(jax.process_count()) "
      "instead. jax.host_ids will eventually be removed; please update your "
      "code.")
  return list(range(process_count(backend)))


def using_pjrt_c_api(backend=None):
  return "PJRT C API" in get_backend(backend).platform_version


# TODO(parkers): Get rid of this in favor of a generic way to get topologies.
def make_pjrt_tpu_topology(topology_name='', **kwargs):
  # TODO(b/261484192): Make a system for lazily loading libtpu.so and call
  # that inside make_tfrt_tpu_c_api_device_topology.
  get_backend()  # Properly initialize libtpu.so.
  return xla_client.make_tfrt_tpu_c_api_device_topology(
      topology_name, **kwargs
  )