Getting Started with Axolotl: Fine-Tuning Gemma 2B#

Axolotl is “a tool designed to streamline the fine-tuning of various AI models.” It is primarily for training Hugging Face models via full fine-tuning, LoRA, QLoRA, ReLoRA, or GPTQ. Configurations are specified in YAML files. It supports a variety of dataset formats, works with additional libraries such as xFormers and FlashAttention, is compatible with FSDP and DeepSpeed for multi-GPU training, and can log to MLflow or WandB.
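For orientation, a minimal config looks roughly like the following. This is an illustrative sketch, not a file from the repo; the model and dataset values are placeholders in the spirit of Axolotl’s examples, and the exact fields vary from example to example.

# minimal-config sketch (illustrative placeholders)
base_model: openlm-research/open_llama_3b_v2
load_in_4bit: true
adapter: qlora                 # or lora; omit for a full fine-tune
datasets:
  - path: mhenrichsen/alpaca_2k_test
    type: alpaca               # one of the supported prompt formats
sequence_len: 2048
micro_batch_size: 2
num_epochs: 1
output_dir: ./outputs/out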

The recommended workflow is to pick a quickstart notebook from the examples directory and modify it as needed.

Let’s fine-tune the Gemma 2B model using Axolotl. There is already an example config for fine-tuning the 7B model, so we will adapt that to our needs.

Note that I ran all of this in a Databricks workspace on a single A10 GPU, using QLoRA for fine-tuning.

Setup#

First, we install the necessary dependencies, following the quickstart in the Axolotl README.

This part is important and less straightforward than it might seem. You need to install the PyTorch build that corresponds to your CUDA version (see here), and make sure your pip version is up to date. Getting this part right took a fair amount of trial and error.

%pip install --upgrade pip
# cu118 builds target CUDA 11.8; pick the index URL that matches your CUDA version
%pip install --upgrade torch==2.2.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

%pip install --upgrade mlflow
%pip install --upgrade packaging deepspeed transformers
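A quick sanity check (my addition, not part of the quickstart) confirms that the installed build matches:

import torch
# should print the torch version, the CUDA version it was built against,
# and whether a GPU is visible
print(torch.__version__, torch.version.cuda, torch.cuda.is_available())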

Next, clone the axolotl repository.

%%sh
git clone https://github.com/OpenAccess-AI-Collective/axolotl

cd into the axolotl directory and install Axolotl with the flash-attn and deepspeed extras.

%%sh
cd axolotl
pip install -e '.[flash-attn,deepspeed]'

Next, because Gemma is in a gated repo on Hugging Face, we need to log in to Hugging Face before we can download and train the model.

from huggingface_hub import login
login()
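login() prompts interactively for an access token. For a non-interactive job, you can instead pass the token directly; HF_TOKEN below is an assumed environment variable name, not part of the original setup:

import os
from huggingface_hub import login

# read the access token from an environment variable instead of prompting
login(token=os.environ["HF_TOKEN"])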

Obtain and modify the training YAML file#

Axolotl is a configuration-based fine-tuning tool. Let’s get the configuration from the Gemma 7B example file and modify it.

%%sh
wget https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/examples/gemma/qlora.yml
--2024-03-28 16:49:54--  https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/examples/gemma/qlora.yml
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8001::154, 2606:50c0:8003::154, 2606:50c0:8000::154, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8001::154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1095 (1.1K) [text/plain]
Saving to: ‘qlora.yml’

     0K .                                                     100% 45.4M=0s

2024-03-28 16:49:55 (45.4 MB/s) - ‘qlora.yml’ saved [1095/1095]

We make the following modifications to the configuration (see the sketch after this list):

  • change the model to google/gemma-2b

  • change sequence_len to 2048 (otherwise we will encounter OOM errors)
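The relevant part of the modified qlora.yml then looks roughly like this (only these two fields change; everything else stays as in the upstream example):

base_model: google/gemma-2b   # changed from the 7B model
# ... remaining fields unchanged from the example ...
sequence_len: 2048            # lowered to avoid OOM on a single A10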

With this completed, we can run the QLoRA fine-tuning job.

Pre-process the data#

The Axolotl repo quickstart recommends pre-processing the data before training. This can be done as follows (shown here, as in the quickstart, with the OpenLLaMA example config and dataset):

%%sh
cd axolotl
# an empty CUDA_VISIBLE_DEVICES forces pre-processing to run on CPU
CUDA_VISIBLE_DEVICES="" python -m axolotl.cli.preprocess examples/openllama-3b/lora.yml
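To pre-process our own Gemma config instead, point the same command at the file we downloaded (assuming qlora.yml sits in the notebook’s working directory):

%%sh
CUDA_VISIBLE_DEVICES="" python -m axolotl.cli.preprocess ./qlora.yml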

Train the Model#

With the modified config in place, launch the fine-tuning run:

%%sh
accelerate launch -m axolotl.cli.train ./qlora.yml
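Checkpoints and the trained adapter are written to the output_dir specified in the YAML config, and training metrics go to whichever logger is configured (MLflow or WandB).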

Resolving Issues#

Hopefully, this just works as-is. However, I ran into a few issues across the different environments I tested this in. Here are a few tips for debugging:

  • Double-check your Python, torch, and CUDA versions. At the time of writing, Axolotl requires Python >=3.10 and PyTorch >=2.1.1.

  • Make sure PyTorch is compiled with the correct CUDA version.

  • Make sure pip is up to date.

  • Get the torch/CUDA installation figured out before worrying about the other dependencies.

There were a few Databricks-specific issues as well.

  • I had to run databricks configure in the CLI and enter my credentials; otherwise I ran into errors. I believe this is because the transformers library attempts to autolog to MLflow.

  • To log to MLflow, the following lines were necessary in the config:

# mlflow configuration
mlflow_tracking_uri: "<your_mlflow_tracking_uri>"
mlflow_experiment_name: "<your_mlflow_experiment>"
# optionally, save checkpoints to artifact registry
hf_mlflow_log_artifacts: false
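On Databricks specifically, the tracking URI is typically "databricks" and the experiment name is a workspace path such as /Users/<you>/<experiment-name>.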

To avoid using MLflow entirely, even without this config, I needed to prepend the accelerate launch... command with export DISABLE_MLFLOW_INTEGRATION=true. Otherwise, the autologging built into the transformers library (I think?) would attempt to log to MLflow but hit an error without a run or experiment set.
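Concretely, the launch cell then becomes:

%%sh
export DISABLE_MLFLOW_INTEGRATION=true
accelerate launch -m axolotl.cli.train ./qlora.yml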

For more, see this guide on debugging Axolotl.