
LLM

Installation

To install, run

pipenv install --dev --skip-lock --extra-pip-args="-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"

For CI/CD, tox needs a requirements.txt file. Create it with

pipenv run pip freeze > requirements.txt && echo "-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" >> requirements.txt

Torch Models and env vars for mixed precision training

Ampere machines enable the use of FlashAttention-2 and mixed-precision training in bfloat16. Two environment variables control these features:

  • FLASH_ATTENTION_2: can be set to either True or False
  • FLOAT_PRECISION: can be float16, bfloat16, or None. In the last case, it defaults to float32. Note that FLOAT_PRECISION sets the dtype of the foundation-model weights and of the activations during training via torch.autocast.
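As a minimal sketch of how these two variables might be interpreted (the variable names come from this README; the helper functions themselves are hypothetical, and the actual training code applies the resolved dtype through torch.autocast):

```python
import os

def resolve_float_precision(env=os.environ):
    """Map FLOAT_PRECISION to a dtype name; unset/None falls back to float32."""
    value = env.get("FLOAT_PRECISION")
    if value in ("float16", "bfloat16"):
        # Used both for the foundation-model weights and, via torch.autocast,
        # for the activations during training.
        return value
    return "float32"  # default when the variable is unset or None

def flash_attention_enabled(env=os.environ):
    """FLASH_ATTENTION_2 is a string env var set to True or False."""
    return env.get("FLASH_ATTENTION_2", "False") == "True"
```

Passing a dict instead of os.environ makes the helpers easy to test without mutating the process environment.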

Sharding

JAX's sharding API is an implementation of the GSPMD paper.

GSPMD auto-completes the sharding of every tensor from limited user annotations by applying the following rules:

  • Preserved dimensions in outputs: some operators preserve some dimensions. For these preserved dimensions, the outputs are sharded along the same dimensions as the inputs. This may not be optimal, but it is more intuitive.

  • Merging compatible shardings: sharded dimensions may come from different inputs. XLA tries to merge these sharded dimensions in the output if they are compatible.

  • Iterative, priority-based sharding propagation: the compiler propagates the user-specified shardings through operators. Element-wise operations (max, activations, add, ...) are assigned higher sharding priorities, since it is more logical for the inputs and outputs of such operations to have the same sharding. Operators like Dot that change the dimensions are assigned lower priorities (propagated later).

  • Partial specification: the API supports partial user annotation, where the user lets the program decide how to shard some tensors on some dimensions.

  • Guide for users: sometimes multiple sharding patterns are possible, and the user might want to annotate the result to guide the compiler.
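The first two rules above can be sketched as a pure-Python toy model (this is an illustration, not XLA's actual propagation algorithm): represent a sharding as a tuple with one entry per dimension, either a mesh-axis name or None for replicated/unspecified.

```python
def propagate_elementwise(*input_shardings):
    """Toy propagation for an element-wise op: every dimension is preserved,
    so the output inherits the input shardings dimension by dimension, and
    compatible shardings coming from different inputs are merged."""
    rank = len(input_shardings[0])
    out = []
    for d in range(rank):
        # Collect the concrete (non-None) axis names proposed for dim d.
        axes = {s[d] for s in input_shardings if s[d] is not None}
        if len(axes) > 1:
            # Two inputs shard the same dimension along different mesh axes.
            raise ValueError(f"incompatible shardings on dim {d}: {axes}")
        out.append(axes.pop() if axes else None)
    return tuple(out)

# One input sharded on dim 0, the other on dim 1: the merged output
# is sharded on both dimensions.
propagate_elementwise(("data", None), (None, "model"))  # -> ("data", "model")
```

In real code the equivalent annotation is made with jax.sharding (e.g. NamedSharding over a device mesh) and jax.lax.with_sharding_constraint; the compiler then fills in everything left unspecified.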

Multi-host TPU

Admin

The VM launching the scripts needs permissions to handle TPUs. If you launch the scripts from a Google Cloud VM, set Cloud API access scopes to "Allow full access to all Cloud APIs" under API and identity management in the VM configuration.

Launch TPU job

Run source ./scripts/devices/tpu/commands.sh to expose commands to manage TPU jobs:

  • Create a TPU pod VM tpu create v2-32
  • Copy code & install dependencies tpu setup
  • Copy all required files tpu copy
  • Copy a single file or folder tpu copy $FILE_PATH (e.g. tpu copy examples)
  • Run command tpu run '$COMMAND' (e.g. tpu run 'python examples/train.py')
  • Run detached command tpu launch '$COMMAND'
  • Monitor detached command tpu check or tpu log (checks every second)
  • Monitor memory usage tpu memory (run jax-smi)
  • Stop launched computation tpu stop
  • Delete the TPU pod VM tpu delete
  • Run tensorboard tpu tensorboard $LOGDIR

Setup remote computation

gcloud compute scp ~/.ssh/gcp ng-gpu-vm:~/.ssh/id_rsa --compress --zone=europe-west4-a --project=sarus-ai
gcloud compute scp ~/.ssh/gcp.pub ng-gpu-vm:~/.ssh/id_rsa.pub --compress --zone=europe-west4-a --project=sarus-ai
gcloud compute ssh --zone=europe-west4-a --project=sarus-ai ng-gpu-vm --command='git config --global user.email "<email address>"'
gcloud compute ssh --zone=europe-west4-a --project=sarus-ai ng-gpu-vm --command='git config --global user.name "<Full Name>"'

Dataset to test?

Software to use?

https://wandb.ai/jax-series/simple-training-loop/reports/Writing-a-Training-Loop-in-JAX-and-Flax--VmlldzoyMzA4ODEy

LLaMA

LLaMA models can be run via the script in examples/ex_llama.py. Good to know:

  • the list of packages is in requirements_llama.txt
  • the first time the weights are loaded and converted to HF format, they have to fit in CPU memory.
  • gradient checkpointing is activated by default; to deactivate it, change the argument of prepare_for_kbit_training in the train_state method.
  • per-sample gradients are computed via the hooks method of opacus; it is possible to use expanded weights by changing, in the add_dp_train_state method, the grad_sample_mode argument in both prepare_module and prepare_optimizer to ew.
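To make the per-sample-gradient point concrete, here is a dependency-free toy illustration of the concept (this is NOT the opacus implementation, which attaches hooks to modules; it just shows what "one gradient per sample" means for a scalar linear model with squared-error loss):

```python
def per_sample_grads(w, xs, ys):
    """Gradient of 0.5 * (w*x - y)**2 with respect to w, one per sample."""
    return [(w * x - y) * x for x, y in zip(xs, ys)]

def batch_grad(w, xs, ys):
    """Ordinary batch gradient: the average of the per-sample gradients."""
    grads = per_sample_grads(w, xs, ys)
    return sum(grads) / len(grads)
```

DP-SGD needs the individual per-sample gradients (to clip each one before averaging and adding noise), which is why opacus materializes them instead of only the averaged batch gradient.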

Setting up nvidia docker