Fine-Tuning TinyLlama on a Single GPU#

The TinyLlama project “aims to pretrain a 1.1B Llama model on 3 trillion tokens.” At 1.1B parameters, it represents a considerable step up from the small GPT model we previously fine-tuned. That model had 124M parameters; TinyLlama, while still small by the standards of most widely-used LLMs, is almost ten times the size. We will need around 20GB VRAM at a bare minimum to fine-tune this model with mixed precision.
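
Here’s a rough back-of-the-envelope version of that estimate. It assumes bf16 weights and gradients, fp32 AdamW moment buffers, and an fp32 master copy of the weights (many mixed-precision setups keep one), and it ignores activations and framework overhead, so treat it as a sketch rather than an exact figure:

# Back-of-the-envelope VRAM estimate for fully fine-tuning a 1.1B-parameter model
# (bf16 weights/gradients, fp32 AdamW moments and master weights; activations excluded)
n_params = 1.1e9
weights_gb   = n_params * 2 / 1e9  # ~2.2 GB, bf16 weights
grads_gb     = n_params * 2 / 1e9  # ~2.2 GB, bf16 gradients
adam_gb      = n_params * 8 / 1e9  # ~8.8 GB, two fp32 moment buffers
fp32_copy_gb = n_params * 4 / 1e9  # ~4.4 GB, fp32 master weights
print(round(weights_gb + grads_gb + adam_gb + fp32_copy_gb, 1), "GB before activations")  # ~17.6 GB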

Instruction Tuning#

We are going to focus on instruction tuning in this example. Instruction tuning is a supervised learning technique in which we train the model on instruction/output pairs so that it learns to follow human instructions. Before instruction tuning, the base model is trained on next-token completion. We saw this in the GPT2 example: we provided the start of a story and the model completed it. An instruction-tuned model, on the other hand, is trained to answer a question or follow an instruction.

This repository contains a wealth of information on the current state of the field of instruction tuning.

The TinyLlama repository includes scripts for fine-tuning. While these will be useful references, we will proceed with an approach similar to that used in the gpt2 and t5-small notebooks, purely for the sake of making this notebook a reasonable learning step following those.

Here’s an example to make the idea of instruction tuning clearer. These are actual results from the base TinyLlama and the fine-tuned TinyLlama models:

Question: “How do I make coffee with a Chemex coffee maker?”

Answer before instruction tuning:

How do I make coffee with a Chemex coffee maker? How do you make coffee with a Chemex? What is the best way to brew coffee with a Chemex? How do you make coffee with a Chemex filter? How do you make coffee with a Chemex filter and a Chemex? How do you make coffee with a Chemex filter and a Chemex filter? How do you make coffee with a Chemex filter and a Chemex filter? How do you make coffee with a Chemex filter and

Answer after instruction tuning:

To make coffee using a Chemex coffee maker, follow these simple steps:

  1. Prepare the Chemex: First, ensure that your Chemex is clean and dry before use. You can wash it in warm water with mild dish soap or wipe it down with a soft cloth to remove any residue from previous uses.

  2. Fill the Chemex: Measure out one cup of ground coffee into the bottom of the Chemex. If you’re making a large batch, feel free to double this amount for extra strength.

  3. Add hot water: Pour about half an inch (15 millimeters) of boiling water over the coffee grounds. This will create steam that helps heat the water and dissolve the coffee. Be careful not to let the water touch the coffee beans as they could burn each other.

  4. Let the mixture steep: Allow the mixture to sit undisturbed for at least five minutes. The longer the water sits on top of the coffee, the stronger the flavor will become.

  5. Strain the liquid: Once the time has passed, carefully strain the liquid through a fine-meshed filter or cheesecloth. Use a spoon to gently push the liquid through the filter so that all the coffee grounds are removed.

  6. Serve: Now that the coffee is ready, pour the filtered liquid into a glass and enjoy! Alternatively, you can serve it straight from the Chemex by placing it under cold running water until it reaches room temperature. Remember, always read the instructions provided with your Chemex coffee maker to avoid potential hazards such as burning yourself while handling hot liquids.

So we’re not going to teach TinyLlama how to make coffee, but we will teach it how to respond to questions and instructions.

The Data#

We will use the SlimOrca dataset. This is a curated subset of the much larger OpenOrca dataset. Why this dataset? It’s one of the most popular sources of instruction data on Hugging Face, and its size is more manageable than the full OpenOrca dataset. That’s all!

1. Load the model and try some examples#

We’ll begin, as always, by loading the model and trying out some examples.

%pip install --upgrade -r ./tinyllama_requirements.txt
# Some Environment Setup
OUTPUT_DIR = "../results/TinyLlama/" # the path to the output directory; where model checkpoints will be saved
LOG_DIR = "../logs/TinyLlama/" # the path to the log directory; where logs will be saved
CACHE_DIR = "../cache/TinyLlama/" # the path to the cache directory; where cache files will be saved
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_ckpt = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"

tokenizer = AutoTokenizer.from_pretrained(
    model_ckpt,
)

tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_ckpt,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",
)

We use a few techniques to reduce memory usage and improve efficiency when loading the model:

  • We load the model in bfloat16. This reduces the precision but also reduces the memory footprint, leaving more memory headroom for training.

  • We use flash attention, which allows us to use longer sequence lengths without as much memory overhead.

Test some Prompts#

We’ll start with a simple completion-structured prompt, which we know this model can handle. In this type of prompt, we provide a partial text and expect the model to finish it.

# Inference
def generate(prompt, max_new_tokens=100):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    gen_tokens = model.generate(input_ids, max_new_tokens=max_new_tokens,
                                eos_token_id=tokenizer.eos_token_id,
                                repetition_penalty=1.1)
    return tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0]

print(generate("Here are step-by-step instructions to make a great cup of coffee with a Chemex coffee maker:\n1."))

What happens if, instead, we ask a question or give an instruction?

# Question
print(generate("How do I make coffee with a Chemex coffee maker?"))
How do I make coffee with a Chemex coffee maker?
How do you make coffee with a Chemex?
What is the best way to brew coffee with a Chemex?
How do you make coffee with a Chemex filter?
How do you make coffee with a Chemex filter and a Chemex?
How do you make coffee with a Chemex filter and a Chemex filter?
How do you make coffee with a Chemex filter and a Chemex filter?
How do you make coffee with a Chemex filter and
# Instruction
print(generate("Tell me how to make coffee with a Chemex coffee maker."))
Tell me how to make coffee with a Chemex coffee maker.
I'm not sure if you can get the Chemex in the UK, but I think it's worth a try.
The Chemex is a great coffee maker and I love mine!
I have one of these and I love it. It makes a great cup of coffee every time.
I have a Chemex and I love it. It makes a great cup of coffee every time.
I have a Chemex and I love it. It makes a great cup

These did not work because the model has not been instruction tuned. Our task is to change that!

2. Getting and Exploring the Data#

We will be using the SlimOrca dataset, which contains 518,000 instruction and response pairs.

from datasets import load_dataset
from pathlib import Path

slimorca = load_dataset('Open-Orca/SlimOrca',
                           cache_dir=str(Path(CACHE_DIR) / "data"))

Here’s one record from the dataset:

import json
print(json.dumps(slimorca["train"][0], indent=4))
{
    "conversations": [
        {
            "from": "system",
            "value": "You are an AI assistant. You will be given a task. You must generate a detailed and long answer.",
            "weight": null
        },
        {
            "from": "human",
            "value": "Write an article based on this \"A man has been charged with murder and attempted murder after a woman and the man she was on a date with were stabbed at a restaurant in Sydney, Australia.\"",
            "weight": 0.0
        },
        {
            "from": "gpt",
            "value": "Title: Tragedy Strikes in Sydney: Victims Stabbed during Dinner Date, Suspect Charged with Murder and Attempted Murder\n\nIntroduction\n\nA seemingly ordinary evening took a tragic turn in Sydney, Australia when a romantic dinner date at a local restaurant ended in bloodshed, as a man allegedly attacked a couple, stabbing both victims. Local authorities have charged the suspect with murder and attempted murder in connection with the horrific crime. In this article, we delve into the chilling incident, the suspect, and the impact it has had on the community.\n\nThe Incident\n\nThe alarming event took place at a popular restaurant in Sydney, Australia. A woman and her date were enjoying their meal \u2013 blissfully unaware of the tragic circumstances that would soon unfold. As they shared conversation and indulged in their dinner, a man approached the couple and, without warning, began to stab them with a sharp object.\n\nPatrons at the restaurant scrambled for safety, as the assailant continued his violent rampage. Bystanders quickly called emergency services, who promptly arrived at the scene. Although both victims were rushed to a nearby hospital, the woman succumbed to her injuries. The man she was on a date with remained in critical condition, fighting for his life.\n\nThe Suspect\n\nInitial investigations revealed that the alleged attacker had no known connections to the victims \u2013 adding to the mystifying nature of this sudden and brutal assault. Witnesses reported that the assailant seemed to have no specific motive and appeared to carry out the act senselessly.\n\nFollowing a thorough investigation, local police identified and arrested the suspect. During the process, it was discovered that the alleged attacker had a history of criminal behavior and a troubled past, though it is unclear if this played a role in the tragic incident.\n\nAuthorities have formally charged the man with murder and attempted murder in connection with the heinous crime. He awaits a hearing to determine a trial date and, if convicted, could face a life sentence in prison.\n\nThe Community's Response\n\nThe shocking nature of the crime has left the residents of Sydney reeling, as they struggle to come to terms with the harrowing event. The restaurant where the attack occurred has since been closed, with a makeshift memorial being created outside to commemorate the victims.\n\nMany have questioned how such a vicious crime could happen in what is considered to be one of the safest cities in the world. This tragic event has spurred local officials to reassess current security measures and devise strategies to reinforce public safety. Additionally, discussions surrounding mental health and criminal rehabilitation have surfaced as residents seek to comprehend the actions of the alleged perpetrator and prevent similar incidents from occurring in the future.\n\nIn the wake of the stabbing, the community has banded together with an outpouring of grief and support for the victims and their families. Candlelight vigils have been held, and an online fundraising campaign is underway to assist the surviving victim with his medical expenses and recovery.\n\nConclusion\n\nThe tragic attack in Sydney serves as a chilling reminder that senseless acts of violence can happen anywhere and at any time. The community's response to this horrific and seemingly random act of brutality has been one of solidarity and determination to prevent such incidents in the future. 
As the case unfolds, the victims and their families remain in the hearts of the community, who are grieving the devastating loss of a life cut tragically short and supporting the recovering victim as he continues to endure this unimaginable ordeal.",
            "weight": 1.0
        }
    ]
}

You’ll see that there are three components to the sample entry:

  1. A system message: this should be familiar if you’ve used e.g. ChatGPT via the API. This is a general instruction specifying the model’s role/identity and overall behavior.

  2. A human message: this is the specific instruction passed to the model by a human.

  3. A gpt message: this is the AI model’s response.

So we want to use this dataset to fine-tune the model such that it will respond more like the gpt message when given the system and human messages.

3. Formatting the Data#

First, we need to get these entries into a format we can actually use for fine-tuning. Appendix A digs into the fine-tuning code from the TinyLlama repo to see how the authors handled data formatting. We’re going to take a slightly different approach and use the chat model templates from the Transformers library. The Hugging Face docs recommend applying the chat templates as a preprocessing step. Let’s take a look at how they work.

Transformers Chat Templates#

Chat Templates are attributes of tokenizers. If a chat template isn’t set explicitly, the default template for that model class is used. Let’s see if there is a chat template set here.

print(tokenizer.chat_template)
print(tokenizer.default_chat_template)
No chat template is defined for this tokenizer - using the default template for the LlamaTokenizerFast class. If the default is not appropriate for your model, please set `tokenizer.chat_template` to an appropriate template. See https://huggingface.co/docs/transformers/main/chat_templating for more information.
None
{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true and not '<<SYS>>' in messages[0]['content'] %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don\'t know the answer to a question, please don\'t share false information.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\n' + system_message + '\n<</SYS>>\n\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\n' + content.strip() + '\n<</SYS>>\n\n' }}{% elif message['role'] == 'assistant' %}{{ ' '  + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}

There is no chat template defined for this tokenizer, so the default template for the LlamaTokenizerFast class will be used. To use the template, we first need to adjust the slimorca records to match the expected format: role and content keys instead of from and value, and system/user/assistant roles instead of system/human/gpt. The chat is still structured as a list of dictionaries. Here’s an example of a chat in the expected format:

chat = [
    {"role": "system", "content": "You are a helpful assistant and an expert at making coffee."},
    {"role": "user", "content": "How do I make coffee with a Chemex coffee maker?"},
    {"role": "assistant", "content": "To make coffee with a Chemex:\n1. Boil water to about 200°F (93°C).\n2. Place the Chemex filter in the top and rinse it with hot water to remove paper taste and warm the vessel. Discard the rinse water.\n3. Add coffee grounds to the filter. Use a medium-coarse grind, about 1 gram of coffee per 16 grams of water.\n4. Pour just enough hot water to saturate the grounds. Wait 30 seconds for the coffee to 'bloom'.\n5. Slowly pour the remaining water over the grounds in a circular motion. Aim for a total brew time of 3.5 to 4.5 minutes.\n6. Once brewing is complete, remove the filter and enjoy."}
]

Now we can apply the chat template and obtain a string-formatted chat that we can tokenize and train on.

print(tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False))
<s>[INST] <<SYS>>
You are a helpful assistant and an expert at making coffee.
<</SYS>>

How do I make coffee with a Chemex coffee maker? [/INST] To make coffee with a Chemex:
1. Boil water to about 200°F (93°C).
2. Place the Chemex filter in the top and rinse it with hot water to remove paper taste and warm the vessel. Discard the rinse water.
3. Add coffee grounds to the filter. Use a medium-coarse grind, about 1 gram of coffee per 16 grams of water.
4. Pour just enough hot water to saturate the grounds. Wait 30 seconds for the coffee to 'bloom'.
5. Slowly pour the remaining water over the grounds in a circular motion. Aim for a total brew time of 3.5 to 4.5 minutes.
6. Once brewing is complete, remove the filter and enjoy. </s>

Apply the template to the whole dataset#

Now we need to apply the template to the whole slimorca dataset. We will first convert the slimorca entries into the expected format, and then use tokenizer.apply_chat_template to apply the template.

Note that the instruction format includes some special tokens we would like to add to the tokenizer’s vocabulary. We begin by adding them with tokenizer.add_special_tokens. For some more details on special tokens, see the Data Preprocessing notebook, which goes into much greater detail on the whole preprocessing pipeline.

import torch

# configure the model and tokenizer with chat tokens
# Add the instruction tokens to the tokenizer
special_tokens = ["[INST]", "[/INST]", "<<SYS>>", "<</SYS>>"]
# Adding special tokens to the tokenizer
tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
# Update the model's embeddings accordingly
model.resize_token_embeddings(len(tokenizer))


def format_slimorca(ex):
    role_mapping = {"gpt": "assistant", "system": "system", "human": "user"}
    chat = [
        {"role": role_mapping[message["from"]], "content": message["value"]}
        for message in ex["conversations"]
    ]
    formatted_chat = tokenizer.apply_chat_template(
        chat,
        tokenize=False,  # Apply formatting but do not tokenize
        add_generation_prompt=False,
    )

    # Tokenize using the standard tokenizer method
    tokenized_output = tokenizer(
        formatted_chat,
        add_special_tokens=False,  # apply_chat_template already added special tokens
        padding="max_length",  # pad to the specified length
        max_length=512,  # max length at which to truncate or to which to pad
        truncation=True,  # truncate to the specified length
    )

    return tokenized_output


# Map to the dataset
slimorca_tokenized = slimorca.map(format_slimorca, num_proc=16).remove_columns(
    "conversations"
)

There are many decisions to make during data preprocessing. Here’s a summary of some of the choices we made for this training run:

  • We used the llama-style chat template, but there are many different chat templates available.

  • We set the max sequence length to 512 tokens, and padded or truncated each sequence as needed to make sure all of the sequences were of the same length.

  • We did not handle the instructions and responses differently. The model still follows a typical causal language modeling approach: it predicts the next token given the previous tokens. There are approaches that, for example, compute the loss only on the response portion; for simplicity, we did not use them here (a rough sketch of the idea follows this list).
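
To make that last point concrete, here is a minimal sketch (not used in this notebook) of how prompt tokens could be masked out of the labels so that the loss is computed only on the response. It assumes single-turn chats and relies on [/INST] having been added as a special token, as above; it illustrates the idea rather than the exact approach used in any particular library:

def make_response_only_labels(input_ids, tokenizer):
    # Copy input_ids as labels, then mask everything up to and including the
    # final [/INST] token with -100 so only the response contributes to the loss.
    inst_end_id = tokenizer.convert_tokens_to_ids("[/INST]")
    labels = list(input_ids)
    response_start = max(i for i, t in enumerate(input_ids) if t == inst_end_id) + 1
    labels[:response_start] = [-100] * response_start
    return labels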

What do our data look like after preprocessing?

slimorca_tokenized
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 517982
    })
})

We now have two features: our input ids and an attention mask. An attention mask is a list of 1s and 0s that indicate which tokens are part of the input and which should be ignored. This ensures that the model does not pay attention to padding tokens.
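
Here’s a tiny illustration (on a made-up sentence, padded to just 8 tokens rather than 512) of what the attention mask looks like alongside the input ids:

enc = tokenizer("Making coffee", padding="max_length", max_length=8)
print(enc["input_ids"])       # real token ids followed by repeated pad (eos) ids
print(enc["attention_mask"])  # 1 for real tokens, 0 for padding positions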

Now let’s inspect a single example and make sure it corresponds to the format we expect.

# Inspect one example
print(tokenizer.decode(slimorca_tokenized["train"][25]['input_ids']))
<s>[INST]  <<SYS>> 
You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.
<</SYS>> 

Read this: From 1981 to 2010, the average annual precipitation measured at Seattle–Tacoma International Airport was 37.49 inches (952 mm). Annual precipitation has ranged from 23.78 in (604 mm) in 1952 to 55.14 in (1,401 mm) in 1950; for water year (October 1 – September 30) precipitation, the range is 23.16 in (588 mm) in 1976–77 to 51.82 in (1,316 mm) in 1996–97. Due to local variations in microclimate, Seattle also receives significantly lower precipitation than some other locations west of the Cascades. Around 80 mi (129 km) to the west, the Hoh Rain Forest in Olympic National Park on the western flank of the Olympic Mountains receives an annual average precipitation of 142 in (3.61 m). Sixty miles to the south of Seattle, the state capital Olympia, which is out of the Olympic Mountains' rain shadow, receives an annual average precipitation of 50 in (1,270 mm). The city of Bremerton, about 15 mi (24 km) west of downtown Seattle, receives 56.4 in (1,430 mm) of precipitation annually.

What is the average rainfall in Seattle? 
What is the answer? (If it cannot be answered, return "unanswerable") [/INST]  The average annual precipitation measured at Seattle-Tacoma International Airport from 1981 to 2010 was 37.49 inches (952 mm).

The answer is 37.49 inches (952 mm). </s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>

Note the padding tokens at the end. The whole example was shorter than 512 tokens, so it was padded to reach 512 tokens.

Split the dataset into training and validation#

from datasets import DatasetDict

# Split the tokenized dataset into training and validation sets
slimorca_tokenized_split = slimorca_tokenized['train'].train_test_split(test_size=0.1)

# Format the split datasets into a DatasetDict for compatibility with Hugging Face's Trainer
slimorca_tokenized_split = DatasetDict(
    {
        "train": slimorca_tokenized_split["train"],
        "valid": slimorca_tokenized_split["test"],
    }
)

slimorca_tokenized_split

Now, as in the gpt2 example, we will configure a collator. The collator is responsible for taking inputs, generating labels, and assembling the inputs into batches. See the data preprocessing notebook for more details.

Since we already padded/truncated the inputs to the same lengths, we don’t need anything special here. The DataCollatorForLanguageModeling collator will add labels to each entry. Importantly, the labels are the same as the inputs; the model shifts them internally when computing the loss, so we do not need to implement any custom logic to align the labels.

from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False,
)
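
As a quick sanity check on what the collator produces, we can run it on a couple of tokenized training examples (this assumes the slimorca_tokenized_split dataset from above):

features = [slimorca_tokenized_split["train"][i] for i in range(2)]
batch = data_collator(features)
print(batch["input_ids"].shape, batch["labels"].shape)  # both torch.Size([2, 512])
# With mlm=False, the labels mirror the input ids, except that positions equal to the
# pad token are set to -100 so the loss ignores them. Since pad_token is the eos token
# here, the trailing eos/padding region of each example is masked.
print((batch["labels"] == -100).sum().item())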

4. Fine-tune the model#

Now that the data are ready, we can train the model using the Hugging Face Trainer. This part is similar to the earlier examples; however, we made a few changes to handle the limited memory.

  • We set auto_find_batch_size to True. The trainer will try multiple batch sizes, starting from the specified per_device_train_batch_size, and reduce the batch size if it encounters an OOM error.

  • We use gradient accumulation to simulate a larger batch size. Gradients are accumulated over multiple mini-batches of data (because we cannot use a very large batch size), and the weights are only updated after the specified number of gradient accumulation steps. See the quick effective-batch-size calculation below.
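
With the settings used below, the effective batch size per optimizer step works out as follows (note that auto_find_batch_size may lower the per-device value if it hits an out-of-memory error, which lowers the effective batch size accordingly):

per_device_train_batch_size = 32
gradient_accumulation_steps = 4
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps
print(effective_batch_size)  # 128 sequences per weight update on a single GPU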

from transformers import TrainingArguments, Trainer
import mlflow

# Define the training arguments
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=1,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=4,
    auto_find_batch_size=True,
    warmup_steps=1,
    weight_decay=0.01,
    logging_dir=LOG_DIR,
    logging_steps=25,  # Log every 25 steps
    evaluation_strategy="steps",  # Evaluate every 'eval_steps'
    eval_steps=5000,
    bf16=True,
    # fp16=True,
    gradient_accumulation_steps=4,
    gradient_checkpointing=False,
    # optim="adamw_bnb_8bit",
    save_steps=10000,
)

training_args.set_logging(report_to=["mlflow"], steps=50, level="info")


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=slimorca_tokenized_split["train"],
    eval_dataset=slimorca_tokenized_split["valid"],
    data_collator=data_collator,
)

# Start training and track with MLflow
with mlflow.start_run(log_system_metrics=True):
    trainer.train()
    mlflow.log_params(training_args.to_dict())

trainer.save_model(OUTPUT_DIR + "/final")

5. Load the Fine-Tuned Model Checkpoint and Run some Examples#

tokenizer = AutoTokenizer.from_pretrained(
    "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
)
tokenizer.pad_token = tokenizer.eos_token
special_tokens = ["[INST]", "[/INST]", "<<SYS>>", "<</SYS>>"]
tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})


ckpt = OUTPUT_DIR + "/final"
model = AutoModelForCausalLM.from_pretrained(
    ckpt,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",
)
prompt = [
    {
        "role": "system",
        "content": "You are a helpful assistant and an expert at making coffee.",
    },
    {"role": "user", "content": "Tell me how to make coffee with a Chemex coffee maker."},
]
prompt = tokenizer.apply_chat_template(
    prompt, tokenize=False, add_generation_prompt=False
)


def generate(prompt, max_new_tokens=500):
    input_ids = tokenizer.apply_chat_template(
        prompt, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    gen_tokens = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        eos_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.2,
        do_sample=False,
    )
    return tokenizer.batch_decode(gen_tokens, skip_special_tokens=False)[0]


print(generate(prompt))

Which results in:

To make coffee using the Chemex coffee maker, follow these steps:

  1. Prepare your Chemex coffee maker: First, ensure that you have all necessary materials for brewing coffee, such as a Chemex coffee filter (or any other type of filter), water, ground coffee beans, and a clean spoon or measuring cup.

  2. Measure the amount of water needed: The Chemex coffee maker has two settings - “hot” and “coffee”. You will need about one-third more water than required by the hot setting. For example, if the hot setting requires 8 cups of water, then you’ll need around 3/4ths of this amount in the coffee setting.

  3. Pour the water into the Chemex: Place the Chemex on its side and pour enough water from the tap into it so that there is approximately half an inch of water above the top of the filter. This should fill up most of the space between the bottom of the filter and the rim of the pot.

  4. Add the coffee beans: Carefully place the Chemex lid onto the pot, ensuring that the coffee bean grind level matches the desired strength. If you want a stronger coffee, add more grounds; if you prefer a milder flavor, use less.

  5. Turn on the machine: Close the lid tightly and turn on the machine. It may take some time for the water to heat up properly, but once it starts heating, let it continue to heat until the water reaches boiling point.

  6. Allow the water to reach a rolling boil: Once the water reaches a rolling boil, remove the lid and allow the steam to escape through the opening. Keep checking the temperature regularly to avoid burning yourself while working with the Chemex.

  7. Stir the water: After removing the lid, carefully stir the water with a clean spoon or measuring cup. This helps prevent lumps and uneven extraction of flavors.

  8. Let the water cool down: As soon as the water cools down, remove the lid and transfer the coffee to a serving container. Alternatively, you can leave the coffee in the pot overnight to extract even more flavor.

  9. Enjoy your coffee: Serve the coffee immediately after it has been extracted, or store it in an airtight container for later consumption.

Remember to always check the temperature before drinking the coffee, especially when storing it longer than recommended. Also, keep in mind that different types of coffee require varying amounts of water to achieve optimal taste and quality.

If you have any questions or concerns regarding the process, feel free to ask! 😊

Happy brewing! ☕️

As we said at the beginning, the model still doesn’t really know how to make coffee, but it does know how to answer questions and follow instructions now!

6. Next Steps#

This fine-tuning process pushed the limits of what we could accomplish on a single GPU. And it makes sense: the back-of-the-envelope calculation at the start of this notebook said that we would require around 20GB of VRAM before we even think about storing activations or scaling sequence lengths or batch sizes.

We got around this in part by using a smaller sequence length than that shown in the tinyllama fine-tuning script. The biggest change we made was to use the adamw_bnb_8bit optimizer from the bitsandbytes library. The point is that we are running up against the limits of what we can reasonably accomplish with a single GPU, at least without more sophisticated approaches. So what’s next? There are several directions we can pursue (and we can and should pursue them all):

  1. Try to further optimize training this model on a single GPU. What can we do to make the training process run faster and more effectively? Can we find an approach that will still let us use the standard AdamW optimizer? Can we benefit from using e.g. DeepSpeed ZeRO?

  2. Try to fine-tune this model on a multi-GPU setup. What are the benefits in terms of speed and ability to train larger batches and larger sequence lengths? And, perhaps more importantly in this setting, how do we make the leap from a single GPU to a multi-GPU setup?

  3. Train a bigger model! So far we have fine-tuned t5-small, gpt2, and tinyllama, with each subsequent model larger than the last. We ultimately want to work our way up to even larger models, so after this, it might be time to train a 3B parameter model, and then a 7B parameter model!

Appendix A: Looking at the tinyllama fine-tuning code#

You can find the TinyLlama fine-tuning code here. It’s worth the time, at this phase of learning about fine-tuning, to read through it and learn about some of the approaches the authors use.

Let’s first take a look at the train() method. It begins by using HfArgumentParser to configure the training arguments. Earlier in the script, the authors define a number of dataclasses for e.g. training arguments, data arguments, etc. HfArgumentParser parses command-line arguments directly into instances of these dataclass types. So instead of simply defining arguments in notebook cells, as we’ve been doing, this approach provides a structured way to configure a run from the command line. And, indeed, the repo provides a shell script for running the fine-tuning script with a defined set of arguments.
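
To make the pattern concrete, here is a minimal sketch of how HfArgumentParser is typically used with dataclasses. The DataArguments dataclass below is a simplified stand-in for illustration, not the repo’s exact definition:

from dataclasses import dataclass, field
from transformers import HfArgumentParser, TrainingArguments

@dataclass
class DataArguments:
    # hypothetical example fields, for illustration only
    dataset_name: str = field(default="Open-Orca/SlimOrca")
    max_seq_length: int = field(default=512)

parser = HfArgumentParser((TrainingArguments, DataArguments))
# Parses command-line flags like --output_dir and --dataset_name directly into the dataclasses
training_args, data_args = parser.parse_args_into_dataclasses()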

The next major section of the training script is focused on preparing the data, using the make_data_module method defined earlier in the script. That method is set up to handle a few different potential fine-tuning data sources (slimorca is not included among them). It maps each of them to the expected format: an input string and an output string.

I found the handling of alpaca-formatted datasets instructive. The alpaca dataset includes instructions and optional inputs that follow a specified format. For examples with inputs, the format is:

Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Input:
{input}

### Response:

The following code snippet in the TinyLlama repo handles this formatting (in the alpaca dataset, the inputs/outputs are not pre-formatted):

ALPACA_PROMPT_DICT = {
    "prompt_input": (
        "Below is an instruction that describes a task, paired with an input that provides further context. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response: "
    ),
    "prompt_no_input": (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response: "
    ),
}

def extract_alpaca_dataset(example):
    if example.get("input", "") != "":
        prompt_format = ALPACA_PROMPT_DICT["prompt_input"]
    else:
        prompt_format = ALPACA_PROMPT_DICT["prompt_no_input"]
    return {'input': prompt_format.format(**example)}
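
For example, applied to a hypothetical alpaca-style record (made up here purely for illustration), extract_alpaca_dataset returns the fully formatted prompt string:

example = {
    "instruction": "Summarize the passage in one sentence.",
    "input": "The Chemex is a manual pour-over coffee brewer with a distinctive hourglass shape.",
}
print(extract_alpaca_dataset(example)["input"])
# Below is an instruction that describes a task, paired with an input that provides further context. ...
# ### Instruction:
# Summarize the passage in one sentence.
# ### Input:
# The Chemex is a manual pour-over coffee brewer with a distinctive hourglass shape.
# ### Response: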

We’re seeing some repeating patterns across training scripts (the examples so far in this repo and the TinyLlama code). Each fine-tuning run so far requires the following:

  • process the data

  • set up training arguments

  • set up logging

An additional step, as we get to multi-GPU and multi-node setups, will be configuring devices and processes; see the script.sh shell script from TinyLlama for an example, which uses accelerate launch, a helper command that makes it easier to launch training scripts on different hardware.

Appendix B: Resuming from a Checkpoint#

I made a few mistakes in terms of handling checkpoints. In one of those cases, I saved checkpoints and assumed a final model would also be saved. This was not the case. So I had a checkpoint at step 20,000 out of 29,000. In this case, instead of starting over, it made more sense to load the checkpoint and finish training. To do so with the Hugging Face trainer, we can:

  1. Load the desired model checkpoint with e.g.

model = AutoModelForCausalLM.from_pretrained(
    "/path/to/checkpoint-20000",
    torch_dtype=torch.bfloat16, 
    device_map="auto",
    attn_implementation="flash_attention_2"
)

Also make sure the tokenizer is loaded.

  2. After configuring the trainer/training arguments as before, call trainer.train with the resume_from_checkpoint argument set to the desired checkpoint.

trainer.train(resume_from_checkpoint="/path/to/checkpoint-20000")

The training will then pick up at step 20,000. And then you can make sure to save the final model!