Metadata-Version: 2.1
Name: sarus-llm
Version: 1.1.3
Summary: A library to train LLMs with Sarus
Author: Sarus
License: PRIVATE
Classifier: Programming Language :: Python :: 3.9
Requires-Python: >=3.9
Description-Content-Type: text/markdown
Requires-Dist: tqdm==4.66.1
Requires-Dist: pyarrow~=15.0
Requires-Dist: datasets
Requires-Dist: dp-accounting
Requires-Dist: gcsfs
Requires-Dist: torch>=2.2.0
Requires-Dist: bitsandbytes
Requires-Dist: sentencepiece
Requires-Dist: mistral-common~=1.2.1
Requires-Dist: fastDP@ git+https://github.com/awslabs/fast-differential-privacy.git@c5146e10b191a259d29f85d8319be247584a9bdf
Requires-Dist: deepspeed~=0.15.0
Requires-Dist: tiktoken==0.4.0
Requires-Dist: blobfile
Requires-Dist: omegaconf~=2.3.0
Requires-Dist: safetensors
Requires-Dist: transformers
Requires-Dist: liger-kernel~=0.2.1; sys_platform == "linux"
Provides-Extra: development
Requires-Dist: pytest>=6.2; extra == "development"
Requires-Dist: pytest-mock>=3.6; extra == "development"
Requires-Dist: pytest-cov>=2.12; extra == "development"
Requires-Dist: mypy==1.9.0; extra == "development"
Requires-Dist: pre-commit; extra == "development"
Requires-Dist: ruff; extra == "development"
Requires-Dist: types-tqdm; extra == "development"
# LLM
## Installation
To install, run
```
pipenv install --dev --skip-lock --extra-pip-args="-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
```
For the CI/CD, a `requirements.txt` file needs to be created for tox. Create it using
```
pipenv run pip freeze > requirements.txt && echo "-f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html" >> requirements.txt
```
## Torch Models and env vars for mixed precision training
Ampere machines enable the use of flash attention 2 and mixed-precision training in bfloat16. Two environment
variables control these features:
- FLASH_ATTENTION_2: can be set to True or False
- FLOAT_PRECISION: can be float16, bfloat16, or None; in the last case it defaults to float32. Note that FLOAT_PRECISION sets
the dtype of the foundation model weights and the dtype of the activations during training via torch.autocast (a sketch follows below).
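As an illustration, here is a minimal sketch of how these variables could be wired to `torch.autocast`. The helper names and the `attn_implementation` mapping are assumptions for this example, not this library's actual API.
```python
import os

import torch

# Hypothetical helpers (assumptions, not this library's API) mapping the
# FLOAT_PRECISION and FLASH_ATTENTION_2 environment variables to torch settings.
_DTYPES = {"float16": torch.float16, "bfloat16": torch.bfloat16}

def resolve_precision() -> torch.dtype:
    # Unset or unrecognized values fall back to float32.
    return _DTYPES.get(os.environ.get("FLOAT_PRECISION", ""), torch.float32)

def use_flash_attention_2() -> bool:
    return os.environ.get("FLASH_ATTENTION_2", "False").lower() == "true"

dtype = resolve_precision()
# Could be forwarded as `attn_implementation=` to transformers' from_pretrained.
attn_impl = "flash_attention_2" if use_flash_attention_2() else "eager"

# Activations are autocast to the requested dtype during the forward pass;
# autocast stays disabled when the effective precision is float32.
with torch.autocast(device_type="cuda", dtype=dtype, enabled=dtype != torch.float32):
    ...  # forward/backward of one training step
```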
## Sharding
JAX's sharding API is an implementation of the [GSPMD
paper](https://arxiv.org/pdf/2105.04663.pdf).
GSPMD auto-completes the sharding of every tensor from limited user
annotations by applying the following rules (a usage sketch follows this list):
- **Preserved dimensions in outputs**: some operators preserve some dimensions.
The outputs are sharded along the same dimensions as the inputs for these
preserved dimensions. This may not be optimal, but it is more intuitive.
- **Merging compatible shardings**: sharded dimensions may come from different
inputs. XLA tries to merge these sharded dimensions in the output if they are
compatible.
- **Iterative, priority-based sharding propagations**: the compiler propagates
the user-specified sharding through operators. Element-wise operations (max,
activations, add,...) are assigned higher sharding priorities, since it is
more logical for the inputs and outputs of such operations to have the same
sharding. Operators like `Dot` that change dimensions are assigned lower
priorities (they are propagated later).
- **Partial specification**: the API supports partial user annotation, where
the user lets the compiler decide how to shard some tensors along some
dimensions.
- **Guide for users**: sometimes multiple sharding patterns are possible and
the user may want to annotate intermediate results to guide the compiler.
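For concreteness, here is a minimal sketch of partial user annotation with `jax.sharding`; the mesh axis name, array shapes, and function are illustrative assumptions.
```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1D device mesh; the axis name "data" is only illustrative.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Annotate only the batch dimension; the compiler propagates this sharding
# through downstream operators following the rules above.
x = jax.device_put(jnp.ones((8, 128)), NamedSharding(mesh, P("data", None)))

@jax.jit
def f(x):
    y = jax.nn.relu(x)  # element-wise op: output keeps the input's sharding
    # Optional constraint to guide the compiler on an intermediate result.
    return jax.lax.with_sharding_constraint(y, NamedSharding(mesh, P("data", None)))

out = f(x)
print(out.sharding)
```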
## Multi-host TPU
### Admin
The VM launching the scripts needs to have the permissions to handle TPUs. If
you launch the scripts from a Google cloud VM set `Cloud API access scopes:
Allow full access to all Cloud APIs` under `API and identity management` in the
VM configuration.
### Launch TPU job
Run `source ./scripts/devices/tpu/commands.sh` to expose commands to manage TPU
jobs:
- Create a TPU pod VM `tpu create v2-32`
- Copy code & install dependencies `tpu setup`
- Copy all required files `tpu copy`
- Copy a single file or folder `tpu copy $FILE_PATH` (e.g. `tpu copy examples`)
- Run command `tpu run '$COMMAND'` (e.g. `tpu run 'python examples/train.py'`)
- Run detached command `tpu launch '$COMMAND'`
- Monitor detached command `tpu check` or `tpu log` (polls every second)
- Monitor memory usage `tpu memory` (run jax-smi)
- Stop launched computation `tpu stop`
- Delete the TPU pod VM `tpu delete`
- Run tensorboard `tpu tensorboard $LOGDIR`
## Setup remote computation
```
gcloud compute scp ~/.ssh/gcp ng-gpu-vm:~/.ssh/id_rsa --compress --zone=europe-west4-a --project=sarus-ai
gcloud compute scp ~/.ssh/gcp.pub ng-gpu-vm:~/.ssh/id_rsa.pub --compress --zone=europe-west4-a --project=sarus-ai
gcloud compute ssh --zone=europe-west4-a --project=sarus-ai ng-gpu-vm --command='git config --global user.email "<email address>"'
gcloud compute ssh --zone=europe-west4-a --project=sarus-ai ng-gpu-vm --command='git config --global user.name "<Full Name>"'
```
## Datasets to test?
- https://huggingface.co/datasets/sarus-tech/phee
- https://huggingface.co/datasets/flxclxc/encoded_drug_reviews
## Software to use?
https://wandb.ai/jax-series/simple-training-loop/reports/Writing-a-Training-Loop-in-JAX-and-Flax--VmlldzoyMzA4ODEy
## Llama
Llama models can be run via the script in `examples/ex_llama.py`. Good to know:
- the list of packages is in `requirements_llama.txt`
- the first time the weights are loaded and converted to the Hugging Face format, they have to fit in CPU memory.
- gradient checkpointing is activated by default; to deactivate it, change the argument of `prepare_for_kbit_training` in the method `train_state`.
- per-sample gradients are computed via the `hooks` method of Opacus; to use expanded weights instead, set the `grad_sample_mode` argument
to `ew` in both `prepare_module` and `prepare_optimizer` in the method `add_dp_train_state` (see the sketch below).
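For reference, a minimal, self-contained sketch of per-sample gradients computed through the Opacus hooks mechanism; the toy model and batch are assumptions and unrelated to the Llama training code.
```python
import torch
from opacus import GradSampleModule

# Toy model and batch, purely illustrative.
model = GradSampleModule(torch.nn.Linear(4, 2))
x, y = torch.randn(8, 4), torch.randn(8, 2)

loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()

# The hooks store a per-sample gradient of shape (batch, *param.shape) on each
# parameter; DP-SGD then clips and noises these per-example gradients.
for name, p in model.named_parameters():
    print(name, p.grad_sample.shape)
```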
## Setting up nvidia docker