Instruction Tuning OLMo-1B

Read the blog post on instruction tuning OLMo-1B on mlops.community.

This notebook closely follows the tinyllama instruction-tuning notebook, with a few small adjustments based on lessons learned there. The biggest change is the size of the training set. We trained this model on a much smaller set after noticing that even the first checkpoint from fine-tuning tinyllama already responded appropriately to instructions.

Setup

%pip install --upgrade -r ./olmo_requirements.txt
# Some Environment Setup
OUTPUT_DIR = "../results/olmo/" # the path to the output directory; where model checkpoints will be saved
LOG_DIR = "../logs/olmo/" # the path to the log directory; where logs will be saved
CACHE_DIR = "../cache/olmo/" # the path to the cache directory; where cache files will be saved

Load the model and test some prompts

# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
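import hf_olmo  # OLMo's Hugging Face integration (ai2-olmo package); registers the OLMo classes with the Auto* classes

model_ckpt = "allenai/OLMo-1B"

tokenizer = AutoTokenizer.from_pretrained(
    model_ckpt,
)

model = AutoModelForCausalLM.from_pretrained(
    model_ckpt,
    torch_dtype=torch.bfloat16,  # load weights in bfloat16 to halve memory use vs. float32
    device_map="auto",  # place the model on the available device(s) automatically
    trust_remote_code=True,
)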

Note that, at the time of writing, FlashAttention did not yet work with OLMo.

Test some Prompts

We’ll start with a simple completion-structured prompt, which we know this model can handle. In this type of prompt, we provide a partial text and expect the model to finish it.

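# Inference
def generate(prompt, max_new_tokens=100):
    # Tokenize the prompt and move it to the same device as the model
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    # Generate up to max_new_tokens, stopping early if the model emits the EOS token;
    # a repetition_penalty above 1 discourages the model from repeating itself
    gen_tokens = model.generate(input_ids, max_new_tokens=max_new_tokens,
                                eos_token_id=tokenizer.eos_token_id,
                                repetition_penalty=1.1)
    # Decode the first (and only) sequence; the output includes the prompt
    return tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0]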

print(generate("Here are step-by-step instructions to make a great cup of coffee with a Chemex coffee maker:\n1."))
Here are step-by-step instructions to make a great cup of coffee with a Chemex coffee maker:
1. Fill the filter basket with ground coffee and place it in the bottom of your Chemex.
2. Add hot water to fill the rest of the way, about 1/4 cup (or more if you like strong coffee).
3. Place the filter on top of the water and let sit for 5 minutes.
4. Pour the water out and rinse the filter under cold running water.
5. Repeat steps 2 and 3 until all the water is used up.
6.

What happens if we ask a question or give an instruction instead? Because the model has not been instruction tuned, we should not expect it to handle either one well.

# Question
print(generate("How do I make coffee with a Chemex coffee maker?"))
How do I make coffee with a Chemex coffee maker?
Step 1: Fill the filter basket with ground coffee. Step 2: Pour hot water into the reservoir and let it sit for about 30 seconds to infuse the grounds. Step 3: Add more hot water, if needed. Step 4: Stir the mixture until it is evenly distributed throughout the grounds.
What are the best Chemex filters?
The Best Chemex Filters of 2021 – Reviewed & Compared.
Can you use regular coffee filters in a Chemex?
You can
# Instruction
print(generate("Tell me how to make coffee with a Chemex coffee maker."))
Tell me how to make coffee with a Chemex coffee maker.
I have a chemex and I love it! It's so easy to use, and the results are great. I would definitely recommend it to anyone who is looking for a good coffee maker.

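Neither prompt gets an appropriate response. The question gets a rough answer that drifts into unrelated FAQ-style text, and the instruction produces something that reads like a product review. This happens because the model has not been instruction tuned. Our task is to change that!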

Get and Explore the Data

As in the tinyllama notebook, we will use the SlimOrca dataset, which contains roughly 518,000 instruction/response pairs.

from datasets import load_dataset
from pathlib import Path

slimorca = load_dataset('Open-Orca/SlimOrca',
                        cache_dir=str(Path(CACHE_DIR) / "data"))

Here’s one record from the dataset:

import json
print(json.dumps(slimorca["train"][0], indent=4))
{
    "conversations": [
        {
            "from": "system",
            "value": "You are an AI assistant. You will be given a task. You must generate a detailed and long answer.",
            "weight": null
        },
        {
            "from": "human",
            "value": "Write an article based on this \"A man has been charged with murder and attempted murder after a woman and the man she was on a date with were stabbed at a restaurant in Sydney, Australia.\"",
            "weight": 0.0
        },
        {
            "from": "gpt",
            "value": "Title: Tragedy Strikes in Sydney: Victims Stabbed during Dinner Date, Suspect Charged with Murder and Attempted Murder\n\nIntroduction\n\nA seemingly ordinary evening took a tragic turn in Sydney, Australia when a romantic dinner date at a local restaurant ended in bloodshed, as a man allegedly attacked a couple, stabbing both victims. Local authorities have charged the suspect with murder and attempted murder in connection with the horrific crime. In this article, we delve into the chilling incident, the suspect, and the impact it has had on the community.\n\nThe Incident\n\nThe alarming event took place at a popular restaurant in Sydney, Australia. A woman and her date were enjoying their meal \u2013 blissfully unaware of the tragic circumstances that would soon unfold. As they shared conversation and indulged in their dinner, a man approached the couple and, without warning, began to stab them with a sharp object.\n\nPatrons at the restaurant scrambled for safety, as the assailant continued his violent rampage. Bystanders quickly called emergency services, who promptly arrived at the scene. Although both victims were rushed to a nearby hospital, the woman succumbed to her injuries. The man she was on a date with remained in critical condition, fighting for his life.\n\nThe Suspect\n\nInitial investigations revealed that the alleged attacker had no known connections to the victims \u2013 adding to the mystifying nature of this sudden and brutal assault. Witnesses reported that the assailant seemed to have no specific motive and appeared to carry out the act senselessly.\n\nFollowing a thorough investigation, local police identified and arrested the suspect. 
During the process, it was discovered that the alleged attacker had a history of criminal behavior and a troubled past, though it is unclear if this played a role in the tragic incident.\n\nAuthorities have formally charged the man with murder and attempted murder in connection with the heinous crime. He awaits a hearing to determine a trial date and, if convicted, could face a life sentence in prison.\n\nThe Community's Response\n\nThe shocking nature of the crime has left the residents of Sydney reeling, as they struggle to come to terms with the harrowing event. The restaurant where the attack occurred has since been closed, with a makeshift memorial being created outside to commemorate the victims.\n\nMany have questioned how such a vicious crime could happen in what is considered to be one of the safest cities in the world. This tragic event has spurred local officials to reassess current security measures and devise strategies to reinforce public safety. Additionally, discussions surrounding mental health and criminal rehabilitation have surfaced as residents seek to comprehend the actions of the alleged perpetrator and prevent similar incidents from occurring in the future.\n\nIn the wake of the stabbing, the community has banded together with an outpouring of grief and support for the victims and their families. Candlelight vigils have been held, and an online fundraising campaign is underway to assist the surviving victim with his medical expenses and recovery.\n\nConclusion\n\nThe tragic attack in Sydney serves as a chilling reminder that senseless acts of violence can happen anywhere and at any time. The community's response to this horrific and seemingly random act of brutality has been one of solidarity and determination to prevent such incidents in the future. 
As the case unfolds, the victims and their families remain in the hearts of the community, who are grieving the devastating loss of a life cut tragically short and supporting the recovering victim as he continues to endure this unimaginable ordeal.",
            "weight": 1.0
        }
    ]
}

Format the Data

We format the data in much the same way as in the tinyllama notebook. However, there are a few differences to note.

  • The default chat template is different: it uses the <|im_start|> and <|im_end|> special tokens.

  • By default, the template does not add an <|endoftext|> token at the end of the chat, so we add it manually. If the training data never include <|endoftext|>, the model never learns to emit it. At inference time it then keeps generating until it hits the token limit, instead of stopping once it has addressed the instruction.
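The stopping rule during generation is simple: decoding ends when the model emits the EOS token or when max_new_tokens runs out. The minimal, self-contained sketch below shows why a missing <|endoftext|> matters. The toy "models", the token IDs (EOS = 0), and the helper names greedy_generate, learned_eos, and never_eos are hypothetical stand-ins. They do not come from OLMo's vocabulary or the transformers implementation.

```python
from typing import Callable, List

EOS = 0  # hypothetical EOS id for this toy example


def greedy_generate(
    next_token: Callable[[List[int]], int],
    prompt_ids: List[int],
    eos_token_id: int,
    max_new_tokens: int,
) -> List[int]:
    """Append one token at a time until EOS is emitted or the budget runs out."""
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        token = next_token(ids)
        ids.append(token)
        if token == eos_token_id:  # only a model that learned to emit EOS stops early
            break
    return ids


def learned_eos(ids: List[int]) -> int:
    """Toy model that emits EOS once the sequence has 5 tokens."""
    return EOS if len(ids) >= 5 else 7


def never_eos(ids: List[int]) -> int:
    """Toy model that never saw EOS during training, so it never emits it."""
    return 7


print(greedy_generate(learned_eos, [1, 2], EOS, max_new_tokens=100))
print(len(greedy_generate(never_eos, [1, 2], EOS, max_new_tokens=100)))
```

The first call stops after six tokens ([1, 2, 7, 7, 7, 0]). The second runs to the full budget and returns 102 tokens, prompt included.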

Examine the chat template

print(tokenizer.chat_template)
print(tokenizer.default_chat_template)
No chat template is defined for this tokenizer - using a default chat template that implements the ChatML format (without BOS/EOS tokens!). If the default is not appropriate for your model, please set `tokenizer.chat_template` to an appropriate template. See https://huggingface.co/docs/transformers/main/chat_templating for more information.
None
{% for message in messages %}{{'<|im_start|>' + message['role'] + '
' + message['content'] + '<|im_end|>' + '
'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant
' }}{% endif %}

There is no chat template defined for this tokenizer, so we’ll use the default, which is the ChatML format. In order to use the template, we first need to adjust the slimorca records to match the following format, with role and content instead of from and value keys, and system/assistant/user roles instead of system/gpt/human. The chat is still structured as a list of dictionaries. Here’s an example of a chat in the expected format:

chat = [
    {"role": "system", "content": "You are a helpful assistant and an expert at making coffee."},
    {"role": "user", "content": "How do I make coffee with a Chemex coffee maker?"},
    {"role": "assistant", "content": "To make coffee with a Chemex:\n1. Boil water to about 200°F (93°C).\n2. Place the Chemex filter in the top and rinse it with hot water to remove paper taste and warm the vessel. Discard the rinse water.\n3. Add coffee grounds to the filter. Use a medium-coarse grind, about 1 gram of coffee per 16 grams of water.\n4. Pour just enough hot water to saturate the grounds. Wait 30 seconds for the coffee to 'bloom'.\n5. Slowly pour the remaining water over the grounds in a circular motion. Aim for a total brew time of 3.5 to 4.5 minutes.\n6. Once brewing is complete, remove the filter and enjoy."}
]

Now we can apply the chat template and obtain a string-formatted chat that we can tokenize and train on. Note the lack of a token indicating the end of the string! We will need to add the tokenizer.eos_token to the end of the string manually. This tokenizer did not define a bos_token, so we will proceed without one.

print(tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False))
<|im_start|>system
You are a helpful assistant and an expert at making coffee.<|im_end|>
<|im_start|>user
How do I make coffee with a Chemex coffee maker?<|im_end|>
<|im_start|>assistant
To make coffee with a Chemex:
1. Boil water to about 200°F (93°C).
2. Place the Chemex filter in the top and rinse it with hot water to remove paper taste and warm the vessel. Discard the rinse water.
3. Add coffee grounds to the filter. Use a medium-coarse grind, about 1 gram of coffee per 16 grams of water.
4. Pour just enough hot water to saturate the grounds. Wait 30 seconds for the coffee to 'bloom'.
5. Slowly pour the remaining water over the grounds in a circular motion. Aim for a total brew time of 3.5 to 4.5 minutes.
6. Once brewing is complete, remove the filter and enjoy.<|im_end|>
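To illustrate, the default ChatML template printed earlier can be re-implemented in a few lines of plain Python. This is a sketch, not the tokenizer's code path; the notebook itself relies on tokenizer.apply_chat_template. The names render_chatml, slimorca_to_text, and ROLE_MAPPING are hypothetical. The sketch also covers the SlimOrca role renaming described above and the manual <|endoftext|> append.

```python
from typing import Dict, List

# SlimOrca uses "from"/"value" keys and gpt/human roles; ChatML expects role/content
ROLE_MAPPING = {"gpt": "assistant", "system": "system", "human": "user"}


def render_chatml(messages: List[Dict[str, str]], add_generation_prompt: bool = False) -> str:
    """Mirror the default ChatML template: one <|im_start|>role ... <|im_end|> block per message."""
    text = "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )
    if add_generation_prompt:
        # Open an assistant turn so the model knows it should respond next
        text += "<|im_start|>assistant\n"
    return text


def slimorca_to_text(record: Dict, eos_token: str = "<|endoftext|>") -> str:
    """Convert a SlimOrca record to ChatML and append EOS, which the template omits."""
    messages = [
        {"role": ROLE_MAPPING[m["from"]], "content": m["value"]}
        for m in record["conversations"]
    ]
    return render_chatml(messages) + eos_token


example = {
    "conversations": [
        {"from": "system", "value": "Be brief."},
        {"from": "human", "value": "Hi"},
        {"from": "gpt", "value": "Hello!"},
    ]
}
print(slimorca_to_text(example))
```

At inference time, the same template would be rendered without the assistant message and with add_generation_prompt=True. The prompt then ends with an open <|im_start|>assistant turn for the model to complete.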

Apply the template to the whole dataset

Now we need to apply the template to the whole slimorca dataset. We will first convert the slimorca entries into the expected format, and then use tokenizer.apply_chat_template to apply the template.

We will also add the <|im_start|> and <|im_end|> special tokens to the tokenizer. The <|endoftext|> and <|padding|> tokens are already in the tokenizer’s vocabulary, so we do not need to add them manually.

Unlike the tinyllama notebook, we add the tokenizer.eos_token to the end of the string here. Without doing so, the model does not learn when to stop generating.

# Add the ChatML instruction tokens to the tokenizer as special tokens
special_tokens = ["<|im_start|>", "<|im_end|>"]
tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})

# No need to resize the model's input token embeddings: the embedding
# matrix is already larger than the vocabulary, so it has room for
# the added tokens
# model.resize_token_embeddings(len(tokenizer))

def format_slimorca(ex):
    role_mapping = {"gpt": "assistant", "system": "system", "human": "user"}
    chat = [
        {"role": role_mapping[message["from"]], "content": message["value"]}
        for message in ex["conversations"]
    ]
    formatted_chat = tokenizer.apply_chat_template(
        chat,
        tokenize=False,  # Apply formatting but do not tokenize
        add_generation_prompt=False,
    ) + tokenizer.eos_token # add the end of sequence token

    # Tokenize using the standard tokenizer method
    tokenized_output = tokenizer(
        formatted_chat,
        add_special_tokens=False,  # apply_chat_template already added special tokens
        padding="max_length",  # pad to the specified length
        max_length=512,  # max length at which to truncate or to which to pad
        truncation=True,  # truncate to the specified length
    )

    return tokenized_output


# Map to the dataset
slimorca_tokenized = slimorca.map(format_slimorca, num_proc=16).remove_columns(
    "conversations"
)
slimorca_tokenized
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 517982
    })
})

Now let’s inspect a single example and make sure it corresponds to the format we expect.

# Inspect one example
print(tokenizer.decode(slimorca_tokenized["train"][25]['input_ids']))
<|im_start|>system
You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can. While performing the task think step-by-step and justify your steps.<|im_end|>
<|im_start|>user
Read this: From 1981 to 2010, the average annual precipitation measured at Seattle–Tacoma International Airport was 37.49 inches (952 mm). Annual precipitation has ranged from 23.78 in (604 mm) in 1952 to 55.14 in (1,401 mm) in 1950; for water year (October 1 – September 30) precipitation, the range is 23.16 in (588 mm) in 1976–77 to 51.82 in (1,316 mm) in 1996–97. Due to local variations in microclimate, Seattle also receives significantly lower precipitation than some other locations west of the Cascades. Around 80 mi (129 km) to the west, the Hoh Rain Forest in Olympic National Park on the western flank of the Olympic Mountains receives an annual average precipitation of 142 in (3.61 m). Sixty miles to the south of Seattle, the state capital Olympia, which is out of the Olympic Mountains' rain shadow, receives an annual average precipitation of 50 in (1,270 mm). The city of Bremerton, about 15 mi (24 km) west of downtown Seattle, receives 56.4 in (1,430 mm) of precipitation annually.

What is the average rainfall in Seattle? 
What is the answer? (If it cannot be answered, return "unanswerable")<|im_end|>
<|im_start|>assistant
The average annual precipitation measured at Seattle-Tacoma International Airport from 1981 to 2010 was 37.49 inches (952 mm).

The answer is 37.49 inches (952 mm).<|im_end|>
<|endoftext|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|>

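Note the padding tokens at the end. This example is shorter than 512 tokens, so it was right-padded to max_length=512. Examples longer than 512 tokens are truncated instead, which also cuts off their final <|endoftext|> token.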
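The behavior of padding="max_length" combined with truncation=True can be sketched on plain lists of token IDs. This is a simplified sketch of what the tokenizer does, assuming right-side padding and truncation, which matches the decoded example above. The helper pad_or_truncate and the toy IDs (PAD = 1, EOS = 2) are hypothetical.

```python
from typing import List, Tuple

PAD, EOS = 1, 2  # toy token ids, not OLMo's real vocabulary


def pad_or_truncate(ids: List[int], max_length: int, pad_id: int) -> Tuple[List[int], List[int]]:
    """Right-truncate or right-pad ids to exactly max_length; also return the attention mask."""
    ids = ids[:max_length]  # truncation keeps the beginning and drops the end
    n_pad = max_length - len(ids)
    input_ids = ids + [pad_id] * n_pad
    attention_mask = [1] * len(ids) + [0] * n_pad  # 0 marks padding positions
    return input_ids, attention_mask


short_ids, short_mask = pad_or_truncate([5, 6, 7, EOS], max_length=8, pad_id=PAD)
print(short_ids, short_mask)  # EOS kept, followed by padding

long_ids, _ = pad_or_truncate(list(range(10, 20)) + [EOS], max_length=8, pad_id=PAD)
print(EOS in long_ids)  # the trailing EOS was truncated away
```

The short example keeps its EOS: [5, 6, 7, 2, 1, 1, 1, 1] with mask [1, 1, 1, 1, 0, 0, 0, 0]. The long one loses it. Truncated training examples therefore never show the model where the response should end.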

Split the dataset into training and validation

Here we also limit training to a subset of 10,000 examples, plus 1,000 held out for validation. This choice is based on the LIMIT paper, which found that a small number of high-quality examples is sufficient for instruction tuning. Ideally, we would choose domain-specific examples covering a variety of formats. Since this fine-tuning job is not tailored to a specific domain, we simply select 10,000 random examples from SlimOrca. We could almost certainly get by with fewer examples, especially if they were selected for quality and tailored to the tasks we want the model to succeed at.

from datasets import DatasetDict
from transformers import set_seed

set_seed(123)

slimorca_tokenized_split = slimorca_tokenized["train"].train_test_split(
    train_size=10000, test_size=1000
)

Now we will configure a collator. The collator is responsible for taking inputs, generating labels, and assembling the inputs into batches. See the data preprocessing notebook for more details.

Since we already padded/truncated all inputs to the same length, we don’t need anything special here. With mlm=False, the DataCollatorForLanguageModeling collator adds labels to each example. The labels are a copy of the input IDs, except that padding positions are set to -100 so the loss ignores them. We do not need to shift the labels ourselves; the model shifts them internally when it computes the next-token prediction loss.
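The sketch below walks through the two steps just described on toy token IDs. First it builds labels the way DataCollatorForLanguageModeling(mlm=False) does: it copies input_ids and sets padding positions to -100. Then it applies the one-position shift used by the causal-LM loss. The helpers make_labels and next_token_pairs and the IDs are hypothetical, and the real collator works on PyTorch tensors rather than lists.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # PyTorch's CrossEntropyLoss skips targets equal to this by default
PAD, EOS = 1, 2  # toy token ids


def make_labels(input_ids: List[int], pad_id: int) -> List[int]:
    """Labels are a copy of input_ids, with padding masked out of the loss."""
    return [IGNORE_INDEX if tok == pad_id else tok for tok in input_ids]


def next_token_pairs(input_ids: List[int], labels: List[int]) -> List[Tuple[int, int]]:
    """Shift by one: the prediction at position i is scored against labels[i + 1]."""
    return [
        (current, target)
        for current, target in zip(input_ids[:-1], labels[1:])
        if target != IGNORE_INDEX
    ]


input_ids = [5, 6, 7, EOS, PAD, PAD]
labels = make_labels(input_ids, PAD)
print(labels)                               # [5, 6, 7, 2, -100, -100]
print(next_token_pairs(input_ids, labels))  # [(5, 6), (6, 7), (7, 2)]
```

The last pair is the one that teaches the model to emit EOS. This is also why it matters that <|padding|> and <|endoftext|> are distinct tokens here. If the pad token were the same as the EOS token, the collator would mask the EOS labels as well.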

from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False,
)

Fine-tune the model

Now that the data are ready, we can train the model using the Hugging Face Trainer. This part is essentially the same as in the tinyllama example.

Hyperparameters and Training Arguments

At a high level, this is a fairly naive fine-tuning job. We aren’t trying to excel at a specific benchmark or task; our main goal is simply to equip the model to respond to instructions and questions in an appropriate format. We reached this goal easily with a variety of hyperparameter configurations. That said, the results were fairly sensitive to the learning rate, especially when training on a small subset of the data for multiple epochs. The Trainer's default of 5e-5 was too high and led to overfitting, so we used 8.5e-6.

  • We set auto_find_batch_size to True. The Trainer starts from the specified per_device_train_batch_size and retries with a smaller batch size whenever it hits an out-of-memory (OOM) error.

  • We use gradient accumulation to simulate a larger batch size than fits in memory. Gradients are accumulated over several mini-batches, and the weights are updated only once every gradient_accumulation_steps mini-batches.
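A toy calculation shows why gradient accumulation simulates a larger batch. If each micro-batch gradient is scaled by 1/accum_steps and the results are summed, the total matches the gradient of the full batch. This sketch uses a one-parameter linear model with mean squared error. The helpers grad_mse and accumulated_grad are hypothetical; in the notebook, the Trainer handles the equivalent scaling for you.

```python
import math
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs


def grad_mse(w: float, batch: Batch) -> float:
    """d/dw of the mean squared error of y_hat = w * x over the batch."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w: float, batch: Batch, accum_steps: int) -> float:
    """Sum micro-batch gradients, each scaled by 1 / accum_steps, before one update."""
    if len(batch) % accum_steps != 0:
        raise ValueError("batch size must be divisible by accum_steps")
    size = len(batch) // accum_steps
    total = 0.0
    for step in range(accum_steps):
        micro_batch = batch[step * size:(step + 1) * size]
        total += grad_mse(w, micro_batch) / accum_steps
    return total


data = [(1.0, 2.0), (2.0, 3.0), (3.0, 7.0), (4.0, 8.0)]
full = grad_mse(1.5, data)
accumulated = accumulated_grad(1.5, data, accum_steps=2)
print(full, accumulated, math.isclose(full, accumulated))  # -8.0 -8.0 True
```

The equality is exact here because every micro-batch has the same size and the loss is a per-example mean. Token-level losses with masked padding behave differently: micro-batches can contain different numbers of real tokens, so accumulation only approximates a single large batch.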

from transformers import TrainingArguments, Trainer
import mlflow

# Define the training arguments
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=4,
    auto_find_batch_size=True,
    warmup_steps=1,
    weight_decay=0.01,
    logging_dir=LOG_DIR,
    logging_steps=5,  # Log every 5 steps
    evaluation_strategy="epoch",  # named eval_strategy in newer transformers releases
    #eval_steps=100,
    lr_scheduler_type="linear",
    bf16=True,
    gradient_accumulation_steps=8,
    gradient_checkpointing=False,
    save_steps=10000,
    learning_rate=8.5e-6
)

training_args.set_logging(report_to=["mlflow"], steps=5, level="info")


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=slimorca_tokenized_split["train"],
    eval_dataset=slimorca_tokenized_split["test"],
    data_collator=data_collator,
)

# Start training and track with MLflow
with mlflow.start_run(log_system_metrics=True):
    trainer.evaluate() # eval before starting tuning
    trainer.train()
    mlflow.log_params(training_args.to_dict())

trainer.save_model(OUTPUT_DIR + "/final")
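With these settings, the effective batch size and the number of optimizer updates can be worked out by hand. The sketch below assumes a single GPU and that auto_find_batch_size did not need to shrink the batch. The step count is only an approximation, because different transformers versions handle a final partial accumulation step differently. The helpers effective_batch_size and approx_optimizer_steps are hypothetical.

```python
import math


def effective_batch_size(per_device_batch: int, grad_accum_steps: int, num_devices: int = 1) -> int:
    """Number of examples contributing to each optimizer update."""
    return per_device_batch * grad_accum_steps * num_devices


def approx_optimizer_steps(
    num_examples: int,
    per_device_batch: int,
    grad_accum_steps: int,
    num_epochs: int,
    num_devices: int = 1,
) -> int:
    """Approximate total optimizer updates, dropping any partial accumulation at epoch end."""
    micro_batches_per_epoch = math.ceil(num_examples / (per_device_batch * num_devices))
    updates_per_epoch = max(micro_batches_per_epoch // grad_accum_steps, 1)
    return updates_per_epoch * num_epochs


print(effective_batch_size(32, 8))                          # 256
print(approx_optimizer_steps(10_000, 32, 8, num_epochs=5))  # 195
print(approx_optimizer_steps(10_000, 16, 8, num_epochs=5))  # 390 if the batch size is halved
```

At roughly 200 updates in total, save_steps=10000 is never reached. No intermediate checkpoints are written, and the final model comes from the trainer.save_model call.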

Load the Fine-Tuned Model Checkpoint and Run some Examples

Here are some example outputs from the fine-tuned OLMo-1B model (including the prompts and instruction tokens):

<|im_start|>system
You are a food critic with strong opinions.<|im_end|>
<|im_start|>user
What style of bagel is better, Montreal or New York? Answer briefly.<|im_end|>
<|im_start|>assistant
Montreal's bagels tend to be more flavorful and have softer texture than their New York counterparts. This difference in taste can vary depending on the type of flour used for baking them as well as how they're prepared (e.g., boiled water vs steamed). However, both styles offer unique flavors that cater to different tastes while maintaining consistency across regions. In general, it would depend on personal preference which one you prefer over another based on your own preferences!<|im_end|>
<|endoftext|>

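Instruction tuning teaches the model the format of a good response. However, tuning on general, non-domain-specific instruction/response data cannot compensate for the fact that this is a 1-billion-parameter model that does not know how to make coffee. The response below is well structured, yet it describes a filter basket, a power switch, and a steam mode, none of which a Chemex has.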

<|im_start|>system
You are a helpful assistant and an expert at making coffee.<|im_end|>
<|im_start|>user
Tell me how to make coffee with a Chemex coffee maker.<|im_end|>
<|im_start|>assistant
Making coffee using the Chemex Coffee Maker is easy, but it's important to follow these steps: 

1. Choose your beans – Before you start brewing, choose the right type of coffee bean for this particular brewed coffee. You can find different types in various grocery stores or online shopping platforms like Amazon. Make sure that they're fresh (preferably within 2-3 days) because stale coffee will not taste good when brewed on a regular basis.
2. Prepare the brew – Fill up all necessary measuring scoops into the filter basket before adding ground coffee. This ensures proper extraction by ensuring there isn't too much watery grounds left behind after filtering out excess liquid from the grinds. If possible, use a scale instead of weighing directly onto the countertop since scales tend to be more accurate than hand measurements.
3. Brew slowly – Pour about 1/4 cup (60 ml) of freshly brewed coffee per serving. The amount may vary depending on personal preference; some people prefer less flavorful coffees while others enjoy stronger flavors. It's also essential to remember that if you have multiple cups of coffee served simultaneously during one session, then each person should get their own separate pot so as not to overpower everyone else!
4. Add hot water – After pouring enough coffee through the filter, add warm water back into the machine until the desired temperature has been reached without any lumps present. Be careful not to pour boiling water straight down the drain line, which could potentially damage the valve mechanism inside the unit due to heat buildup caused by excessive pressure generated by high temperatures. Instead, let the system cool off slightly first before turning on steam production mode again.
5. Close the lid and serve immediately – Once everything looks clean and ready, close the door securely and turn on the power switch located underneath the handle. Wait patiently for approximately 10 minutes for the coffee to steepen naturally, allowing time for air bubbles to form around the top cap area where the milk might settle. Then gently lift the cover away once steeping process is complete. Enjoying your deliciously brewed coffee now! Remember, always check the manufacturer's instructions regarding recommended storage times for both preheated and cold beverages such as keeping them refrigerated between 0°C - 5°F (-32°C), otherwise known as "cold" storage period.<|im_end|>

A Note on Weight Tying & “Some weights of OLMoForCausalLM were not initialized from the model checkpoint” Warnings

When loading the fine-tuned model, you might see the following warning:

Some weights of OLMoForCausalLM were not initialized from the model checkpoint at <checkpoint_path> and are newly initialized: ['model.transformer.ff_out.weight'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

That looks troubling, but a little digging into the OLMo codebase clears things up.

At the bottom of modeling_olmo.py, we find:

    def get_output_embeddings(self):
        if self.config.weight_tying:
            return self.model.transformer.wte
        else:
            return self.model.transformer.ff_out

    def set_output_embeddings(self, value: torch.nn.Module):
        if self.config.weight_tying:
            self.model.transformer.wte = value
        else:
            self.model.transformer.ff_out = value

    def tie_weights(self):
        if self.config.weight_tying:
            self.model.transformer.ff_out = self.model.transformer.wte

While OLMo 7B does not use weight tying, OLMo 1B does; you can confirm this in the weight_tying field of the OLMo 1B config. So we should expect model.model.transformer.ff_out == model.model.transformer.wte. And, indeed:

model.model.transformer.ff_out == model.model.transformer.wte

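returns True. Because torch.nn.Module does not override ==, this comparison tests object identity. The ff_out layer is therefore literally the same module as wte, not a new, randomly initialized layer. We can take a closer look at their weights: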

model.model.transformer.wte.state_dict()
OrderedDict([('weight',
              tensor([[-0.0229, -0.0197, -0.0571,  ...,  0.0493, -0.0400,  0.0067],
                      [ 0.0405,  0.0312, -0.0249,  ..., -0.0137, -0.0300,  0.0019],
                      [ 0.0098, -0.0040, -0.0091,  ..., -0.0248, -0.0608,  0.0151],
                      ...,
                      [ 0.0332,  0.0046, -0.0464,  ..., -0.0028,  0.0082, -0.0275],
                      [ 0.0315, -0.0101, -0.0237,  ...,  0.0040,  0.0247,  0.0217],
                      [ 0.0026, -0.0055, -0.0048,  ..., -0.0126, -0.0339,  0.0114]],
                     device='cuda:0', dtype=torch.bfloat16))])
model.model.transformer.ff_out.state_dict()
OrderedDict([('weight',
              tensor([[-0.0229, -0.0197, -0.0571,  ...,  0.0493, -0.0400,  0.0067],
                      [ 0.0405,  0.0312, -0.0249,  ..., -0.0137, -0.0300,  0.0019],
                      [ 0.0098, -0.0040, -0.0091,  ..., -0.0248, -0.0608,  0.0151],
                      ...,
                      [ 0.0332,  0.0046, -0.0464,  ..., -0.0028,  0.0082, -0.0275],
                      [ 0.0315, -0.0101, -0.0237,  ...,  0.0040,  0.0247,  0.0217],
                      [ 0.0026, -0.0055, -0.0048,  ..., -0.0126, -0.0339,  0.0114]],
                     device='cuda:0', dtype=torch.bfloat16))])

In short, this is nothing to worry about. So why do we get the warning in the first place? It originates from the Hugging Face Accelerate library, which checks for tied weights by looking for distinct modules that share weights. OLMo 1B instead ties its weights with self.model.transformer.ff_out = self.model.transformer.wte. The modules themselves, and not just their weights, are the same, so the Accelerate check fails to identify the tied weights and emits the warning.