To install, run:

pipenv install --dev --skip-lock --extra-pip-args="-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
For CI/CD, a requirements.txt file needs to be created for tox. Create it with:

pipenv run pip freeze > requirements.txt && echo "-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" >> requirements.txt
Ampere machines enable flash attention 2 and mixed-precision training in bfloat16. To use these features, two environment variables can be set:
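As a minimal illustration of the bfloat16 mixed-precision pattern in JAX (this is not the project's training code; the shapes and function name are arbitrary):

```python
import jax
import jax.numpy as jnp

# Keep parameters in float32 but run the compute-heavy matmul in bfloat16,
# the usual mixed-precision pattern on Ampere GPUs and TPUs.
params = jnp.ones((128, 64), dtype=jnp.float32)
x = jnp.ones((8, 128), dtype=jnp.float32)

@jax.jit
def forward(params, x):
    # Cast inputs down to bfloat16 for the expensive op...
    y = x.astype(jnp.bfloat16) @ params.astype(jnp.bfloat16)
    # ...and cast the result back up for numerically sensitive steps.
    return y.astype(jnp.float32)

y = forward(params, x)
print(y.dtype, y.shape)
```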
JAX's sharding API is an implementation of the GSPMD paper.
GSPMD auto-completes the sharding of every tensor from limited user annotations by applying the following rules:
- Preserved dimensions in outputs: some operators preserve some dimensions. For these preserved dimensions, outputs are sharded along the same dimensions as the inputs. This may not be optimal, but it is more intuitive.
- Merging compatible shardings: sharded dimensions may come from different inputs. XLA tries to merge these sharded dimensions in the output when they are compatible.
- Iterative, priority-based sharding propagation: the compiler propagates the user-specified shardings through operators. Element-wise operations (max, activations, add, ...) are assigned higher sharding priorities, since it is more natural for the inputs and outputs of such operations to have the same sharding. Operators like Dot that change the dimensions are assigned lower priorities (propagated later).
- Partial specification: the API supports partial user annotation, where the user lets the program decide how to shard some tensors on some dimensions.
- Guiding the compiler: sometimes multiple sharding patterns are possible, and the user may want to annotate the result to guide the compiler.
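The rules above can be sketched with a partial annotation: only the inputs are annotated, and the compiler propagates shardings to the output. This is a minimal sketch, not the project's code; the mesh here is 1x1 so it runs on a single device, whereas on real hardware you would reshape several devices into the mesh:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1x1 device mesh so the sketch runs anywhere; on a pod you would
# reshape e.g. 8 devices into (2, 4).
devices = np.array(jax.devices()[:1]).reshape(1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))

x = jnp.ones((8, 16))
w = jnp.ones((16, 4))

# Partial user annotation: x is sharded along "data" on its first
# dimension and left unsharded (None) on the second; w along "model".
x = jax.device_put(x, NamedSharding(mesh, P("data", None)))
w = jax.device_put(w, NamedSharding(mesh, P(None, "model")))

@jax.jit
def forward(x, w):
    y = x @ w              # Dot changes dimensions: low propagation priority
    return jax.nn.relu(y)  # element-wise: keeps its input's sharding

# The output sharding is not annotated anywhere: the compiler infers it.
y = forward(x, w)
print(y.shape)
```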
The VM launching the scripts needs permission to handle TPUs. If you launch the scripts from a Google Cloud VM, set "Cloud API access scopes" to "Allow full access to all Cloud APIs" under "API and identity management" in the VM configuration.
Run `source ./scripts/devices/tpu/commands.sh` to expose commands to manage TPU jobs:
- `tpu create v2-32`
- `tpu setup`
- `tpu copy $FILE_PATH` (e.g. `tpu copy examples`)
- `tpu run '$COMMAND'` (e.g. `tpu run 'python examples/train.py'`)
- `tpu launch '$COMMAND'`
- `tpu check` or `tpu log` (check every 1 second)
- `tpu memory` (run jax-smi)
- `tpu stop`
- `tpu delete`
- `tpu tensorboard $LOGDIR`

gcloud compute scp ~/.ssh/gcp ng-gpu-vm:~/.ssh/id_rsa --compress --zone=europe-west4-a --project=sarus-ai
gcloud compute scp ~/.ssh/gcp.pub ng-gpu-vm:~/.ssh/id_rsa.pub --compress --zone=europe-west4-a --project=sarus-ai
gcloud compute ssh --zone=europe-west4-a --project=sarus-ai ng-gpu-vm --command='git config --global user.email "<email address>"'
gcloud compute ssh --zone=europe-west4-a --project=sarus-ai ng-gpu-vm --command='git config --global user.name "<Full Name>"'
Llama models can be run via the script in examples/ex_llama.py. Good to know:

- Extra dependencies are listed in `requirements_llama.txt`.
- In `train_state`, change the argument of `prepare_for_kbit_training`.
- Instead of the `hooks` method of opacus, it is possible to use expanded weights: in the method `add_dp_train_state`, change the `grad_sample_mode` argument in both `prepare_module` and `prepare_optimizer` to `ew`.