
How to fine-tune BERT using HuggingFace

Joe Cummings

21 July, 2021 (Last Updated: 23 July, 2021)



Motivation

BERT is built on a machine learning architecture called the Transformer, and Transformers are fascinating. Everyone from those just flirting with NLP to those on the cutting edge will have to use a Transformer-based model at some point.

I’d highly recommend reading through this entire post as it adds color to the model you’ll be building, but if you just want the TL;DR, you can skip to the tutorial part now.

Background

The Transformer architecture was introduced in the paper “Attention Is All You Need” in 2017 and has since been cited over 24k times. The Transformer has proven to be both superior in quality and faster to train by virtue of relying solely on attention mechanisms - doing away with cumbersome convolution and recurrence. I’d highly recommend reading more about it before going further with BERT - see these amazing resources here and here.

Extra challenge: if you have an intuitive way to explain positional encoding, please share it, because it still trips me up sometimes.

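Researchers jumped at the chance to build upon the Transformer architecture, and soon the world had OpenAI's GPT and GPT-2, along with the Allen Institute's ELMo (which came out in the same wave of pretrained language models but is built on LSTMs rather than Transformers). For this tutorial, we'll look at Google's BERT, which was introduced in 2018 and published at NAACL in 2019.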

So, what sets BERT apart? The answer lies in the name: Bidirectional Encoder Representations from Transformers. Previous models, like GPT, encoded sequences using a left-to-right architecture, meaning that every token can only “see” the tokens that came before it. This is sub-optimal for tasks like question answering where it is extremely helpful to integrate context from the entire sequence. BERT’s architecture enables deep bidirectional representations.

Comparison of self-attention - what BERT uses - and masked attention - what GPT uses (source)
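The difference in the figure comes down to a mask applied before the softmax. The NumPy sketch below (an illustration, not code from either model) normalizes the same attention scores two ways: over the whole sequence, as in BERT's bidirectional self-attention, and only over positions up to and including the current token, as in GPT's masked attention. The random scores are a stand-in for the scaled query-key products a real model would compute.

```python
import numpy as np

def attention_weights(scores: np.ndarray, causal: bool) -> np.ndarray:
    """Softmax over raw attention scores of shape (seq_len, seq_len).

    causal=False: every token may attend to every position (BERT-style).
    causal=True: token i may only attend to positions 0..i (GPT-style).
    """
    seq_len = scores.shape[-1]
    if causal:
        # Lower-triangular mask: True where attention is allowed.
        allowed = np.tril(np.ones((seq_len, seq_len), dtype=bool))
        # Disallowed positions get -inf, so they receive zero weight after softmax.
        scores = np.where(allowed, scores, -np.inf)
    # Numerically stable softmax along the last axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
scores = rng.normal(size=(4, 4))  # stand-in scores for a 4-token sequence

bidirectional = attention_weights(scores, causal=False)
causal = attention_weights(scores, causal=True)

print(np.round(causal, 2))  # every entry above the diagonal is 0
```

In the causal version, the first token can only attend to itself, so its entire weight sits on position 0.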

You may be thinking at this point: "Okay, why mask the attention mechanism? Why not just integrate the context from the entire sequence from the start?" The answer is that strictly bidirectional conditioning would allow each token in a sequence to essentially "see itself", and the output would be trivially predicted. Imagine I gave you the sentence "The dog went to the park" and asked you to "predict" what word came after "dog". Since you have the entire sentence as context, you know that "went" immediately follows "dog". While this is a slight oversimplification, it should convey the general idea. The diagram below also helps visualize the difference between these language modeling methods.

Encoding styles of BERT, GPT, and ELMo. ELMo does a shallow concatenation of a left-to-right encoding and a right-to-left encoding. (source)

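So, if naive bidirectional conditioning would let each token see itself, how does BERT manage to be bidirectional? BERT introduces something called a "masked language model" (MLM), which you might also see referred to as a cloze task. In pre-training, 15% of the token positions are selected for prediction: 80% of those are replaced with a special [MASK] token, 10% with a random token, and the remaining 10% are left unchanged.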

The dog went to the park. -> The dog [MASK] to the park.
                          -> The dog banana to the park.

Example of how the sequence “The dog went to the park” would be masked in pre-training of BERT.

The model is then tasked with predicting the original token at each selected position. So rather than processing the left context of a sequence and trying to predict the next token, BERT has to learn to predict tokens at random spots in the sentence, using context from both sides.
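A rough pure-Python sketch of this corruption step follows, using the proportions from the BERT paper noted in the comments. It is a simplification: it works on whitespace-split words rather than WordPiece IDs, selects each position independently with probability mask_prob, and doesn't protect special tokens like [CLS] and [SEP]. The helper name mask_tokens and the toy vocabulary are hypothetical.

```python
import random

def mask_tokens(tokens, vocab, mask_prob=0.15, seed=None):
    """Hypothetical helper sketching BERT-style MLM corruption.

    Returns (corrupted_tokens, targets): targets[i] holds the original token at
    positions selected for prediction and None everywhere else.
    """
    rng = random.Random(seed)
    corrupted = list(tokens)
    targets = [None] * len(tokens)
    for i, token in enumerate(tokens):
        if rng.random() >= mask_prob:
            continue  # not selected: no prediction (and no loss) at this position
        targets[i] = token  # the model must recover the original token here
        roll = rng.random()
        if roll < 0.8:
            corrupted[i] = "[MASK]"            # 80% of selected: [MASK]
        elif roll < 0.9:
            corrupted[i] = rng.choice(vocab)  # 10% of selected: random token
        # remaining 10% of selected: keep the original token unchanged
    return corrupted, targets

sentence = "the dog went to the park .".split()
vocab = ["the", "dog", "went", "to", "park", "banana", "."]
corrupted, targets = mask_tokens(sentence, vocab, seed=3)
print(corrupted)
print(targets)
```

Keeping some selected tokens unchanged means the model can't assume that every non-[MASK] token is correct, which matters because [MASK] never appears during fine-tuning.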

While MLM models the relationships between tokens in a sequence, BERT is also trained with something called "next sentence prediction" (NSP), which models the relationships between sentences. This is very useful for question answering, summarization, and multiple-choice tasks. The data is encoded as shown below.

A: The dog went to the park. 
B: It rolled around in the grass.
Classification: IsNext
---
A: The dog went to the park.
B: The crow caws at midnight.
Classification: NotNext
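In the BERT paper, these pairs are sampled so that half the time sentence B is the sentence that actually follows A (IsNext) and half the time it is a random sentence from the corpus (NotNext). Here is a pure-Python sketch of that sampling, with a hypothetical make_nsp_pairs helper and toy documents (the second sentence of the second document is made up for illustration):

```python
import random

def make_nsp_pairs(documents, seed=None):
    """Hypothetical helper: build (A, B, label) triples for next sentence prediction.

    documents: a list of documents, each a list of consecutive sentences.
    About half the pairs use the true next sentence (IsNext); the rest pair A
    with a random sentence drawn from a different document (NotNext).
    """
    rng = random.Random(seed)
    pairs = []
    for doc_index, doc in enumerate(documents):
        for i in range(len(doc) - 1):
            a = doc[i]
            if rng.random() < 0.5:
                pairs.append((a, doc[i + 1], "IsNext"))
            else:
                other_docs = [d for j, d in enumerate(documents) if j != doc_index]
                b = rng.choice(rng.choice(other_docs))
                pairs.append((a, b, "NotNext"))
    return pairs

documents = [
    ["The dog went to the park.", "It rolled around in the grass."],
    ["The crow caws at midnight.", "The neighbors are not amused."],  # toy text
]
for a, b, label in make_nsp_pairs(documents, seed=0):
    print(f"A: {a}\nB: {b}\nClassification: {label}\n---")
```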

BERT was pre-trained on these two tasks using the BooksCorpus (800M words) and English Wikipedia (2,500M words). Together, these make up the amazing model that is BERT.

Enough background - let’s get to using BERT!

Fine-tuning BERT

There are many different ways in which we could load the BERT model and fine-tune it on a downstream task. Some excellent tutorials with different frameworks can be found below:

For this tutorial, we'll be using the popular Transformers library from HuggingFace to fine-tune BERT on a sentiment analysis task. Despite the slightly silly name, HuggingFace is a fantastic resource for those in NLP engineering. To start, let's create a conda environment and install the HuggingFace library. To support the HuggingFace library, you'll also need to install PyTorch.

1. Setup

conda create --name fine-tune-bert python=3.7
conda activate fine-tune-bert
pip install transformers
pip install torch

In addition, we'll need to install HuggingFace's Datasets package, which offers easy access to many benchmark datasets. Later on, we'll also want to specify our own evaluation metric, and for that we'll use scikit-learn, so go ahead and install that, too.

pip install datasets
pip install scikit-learn

Now that we’ve got everything installed, let’s start the actual work. First, let’s import the necessary functions and objects.

from datasets import load_dataset
from transformers import (
    BertForSequenceClassification,
    BertTokenizer,
    Trainer,
    TrainingArguments,
)

You'll notice that we are importing a version of BERT called BertForSequenceClassification. HuggingFace offers several versions of the BERT model, including a base BertModel, BertLMHeadModel, BertForPreTraining, BertForMaskedLM, BertForNextSentencePrediction, BertForMultipleChoice, BertForTokenClassification, BertForQuestionAnswering, and more. The main difference between many of these is the task-specific layer (the "head") on top of the pretrained model. You can find all of those models and their specifications here. We're using BertForSequenceClassification because we are trying to label a whole sequence of text with a single emotion.
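To illustrate what a task-specific head means, here is a NumPy sketch with made-up values: one encoder output (a hidden vector per token) feeds either a sequence-classification head, which reads only the vector at the [CLS] position, or a token-classification head, which scores every token. The random arrays stand in for BERT's real outputs and trained weights. HuggingFace's BertForSequenceClassification also passes the [CLS] vector through a pooling layer and dropout before its classifier, which this sketch omits.

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, hidden_size, num_labels = 9, 768, 6  # 768 is bert-base's hidden size

# Stand-in for the encoder output: one hidden vector per token.
hidden_states = rng.normal(size=(seq_len, hidden_size))

# Sequence classification head: one linear layer applied to the [CLS] vector.
W_seq = rng.normal(scale=0.02, size=(hidden_size, num_labels))
b_seq = np.zeros(num_labels)
cls_vector = hidden_states[0]  # [CLS] is always the first token
sequence_logits = cls_vector @ W_seq + b_seq  # one score per label

# Token classification head: the same kind of layer applied to every token.
W_tok = rng.normal(scale=0.02, size=(hidden_size, num_labels))
b_tok = np.zeros(num_labels)
token_logits = hidden_states @ W_tok + b_tok  # one score per label per token

print(sequence_logits.shape, token_logits.shape)  # (6,) (9, 6)
```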

The dataset we're using is called "emotion" in HuggingFace's Datasets catalog and consists of 20k English Tweets, each labeled with one of 6 emotions: sadness, joy, love, anger, fear, and surprise. You can read more about how the data was collected, different baseline experiments, and the data distribution in the paper. So let's load the emotion dataset.

emo_dataset = load_dataset("emotion")  # It really is that easy.

Take a peek at the first 5 items in the training data and see what we have.

>>> emo_dataset["train"][:5]
{
    'text': [
        'i didnt feel humiliated',
        'i can go from feeling so hopeless to so damned hopeful just from being around someone who cares and is awake',
        'im grabbing a minute to post i feel greedy wrong',
        'i am ever feeling nostalgic about the fireplace i will know that it is still on the property',
        'i am feeling grouchy'
    ], 
    'label': [0, 0, 3, 2, 3]
}

It appears the text has already been lowercased (good!), links and hashtags have been removed, and contractions have been standardized (apostrophes are dropped, as in "didnt" and "im"). It's also good to double-check that the labels make sense. They have already been converted into numeric values, with each number corresponding to an emotion. For example, 0 is "sadness", 3 is "anger", and 2 is "love".
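To check the labels, it helps to decode the IDs back into names. The emotion dataset's six classes are ordered sadness, joy, love, anger, fear, surprise, consistent with the examples above. The loaded dataset also exposes these names through its label feature (emo_dataset["train"].features["label"].names), but the plain-Python sketch below hard-codes them so it runs without downloading anything:

```python
# Label order of the "emotion" dataset: the list index is the label id.
EMOTION_NAMES = ["sadness", "joy", "love", "anger", "fear", "surprise"]
id2label = dict(enumerate(EMOTION_NAMES))
label2id = {name: idx for idx, name in id2label.items()}

first_five_labels = [0, 0, 3, 2, 3]  # from emo_dataset["train"][:5] above
decoded = [id2label[i] for i in first_five_labels]
print(decoded)  # ['sadness', 'sadness', 'anger', 'love', 'anger']
```

These same dictionaries can optionally be passed to from_pretrained (id2label=id2label, label2id=label2id) so that the model's config reports emotion names instead of generic LABEL_0, LABEL_1, and so on.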

2. Preprocessing

Now that we have some data, we need to preprocess it into a form BERT can understand. Thankfully, HuggingFace provides a helpful BertTokenizer that takes care of this for us.

We can load the BERT tokenizer from the same pretrained checkpoint as the model (the two are published together).

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

And let’s see how this is encoded!

>>> tokenizer(["The dog went to the park."], padding="max_length", truncation=True)
{
    'input_ids': [[101, 1996, 3899, 2253, 2000, 1996, 2380, 1012, 102, 0, ..., 0]], 
    'token_type_ids': [[0, ..., 0]], 
    'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0, ..., 0]]
}

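At this point, you might be thinking: what the hell am I looking at? Well, I've spared your eyes by truncating the run of zeros shown in the output, but each list in the returned dictionary has length 512, which is the maximum sequence length BERT can accept (and the length padding="max_length" pads to for this tokenizer). To some, this might be confusing: why do we need every input sequence to be the same size? The answer is efficiency. The model processes a whole batch of sequences at once as a single tensor, and a tensor has to be rectangular, so every sequence in the batch must be padded (or truncated) to the same length. The input_ids are padded with 0, the ID of the [PAD] token, and the attention_mask marks real tokens with 1 and padding with 0 so the model can ignore the padding.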
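Here is a pure-Python sketch of that padding step, using a hypothetical pad_batch helper and a short max_length of 12 instead of 512 so the output stays readable. The ID lists are what bert-base-uncased produces for "The dog went to the park." and "The dog"; 0 is the [PAD] token ID.

```python
def pad_batch(batch_ids, max_length, pad_id=0):
    """Hypothetical helper: pad or truncate ID lists to a common length.

    Returns (input_ids, attention_mask); the mask is 1 for real tokens and
    0 for padding.
    """
    input_ids, attention_mask = [], []
    for ids in batch_ids:
        # Simple truncation (the real tokenizer keeps [SEP] at the end).
        ids = ids[:max_length]
        n_pad = max_length - len(ids)
        input_ids.append(ids + [pad_id] * n_pad)
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return input_ids, attention_mask

batch = [
    [101, 1996, 3899, 2253, 2000, 1996, 2380, 1012, 102],  # "The dog went to the park."
    [101, 1996, 3899, 102],                                # "The dog"
]
input_ids, attention_mask = pad_batch(batch, max_length=12)
for row in attention_mask:
    print(row)
# [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]
# [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
```

Both rows now have the same length, so the batch can be stacked into a single 2 x 12 matrix.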

So the mapping from above looks like the following:

token    input_id
[CLS]    101
the      1996
dog      3899
went     2253
to       2000
park     2380
.        1012
[SEP]    102

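Keep in mind while debugging that you may see more input_ids than words in the original text. Part of that comes from the special [CLS] and [SEP] tokens, and part comes from the fact that BERT tokenizes using WordPiece, which can split a less common word into two or more subword pieces (continuation pieces are prefixed with "##").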
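WordPiece splits a word by greedy longest-match-first lookup: it takes the longest prefix that is in the vocabulary, then repeats on the remainder, marking continuation pieces with "##". Below is a simplified sketch of that lookup with a tiny hypothetical vocabulary. The real bert-base-uncased vocabulary has about 30,000 entries, and the real tokenizer also lowercases and splits on whitespace and punctuation before this step.

```python
def wordpiece(word, vocab, unk_token="[UNK]"):
    """Greedy longest-match-first WordPiece split of a single word."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest candidate first, shrinking until a vocab entry matches.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation of the same word
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk_token]  # no piece matched: the whole word becomes [UNK]
        pieces.append(match)
        start = end
    return pieces

vocab = {"the", "dog", "play", "##ing", "hope", "##less", "grouch", "##y"}
print(wordpiece("hopeless", vocab))  # ['hope', '##less']
print(wordpiece("grouchy", vocab))   # ['grouch', '##y']
print(wordpiece("dog", vocab))       # ['dog']
```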

So, how can we apply this tokenizer across all of the text in the dataset? One way would be to simply iterate over every entry in the dataset and convert the text like so:

for example in emo_dataset["train"]:
    tokenized_text = tokenizer(example["text"])

This is a slow process and we do have a better option. HuggingFace provides us with a useful map function on the Dataset object.

def tokenize_emotion(examples):
    # With batched=True, examples["text"] is a list of strings, not a single string.
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized_data = emo_dataset.map(tokenize_emotion, batched=True)

In my experience, Dataset.map runs ~200ms faster than linear iteration and automatically caches the result so that each subsequent call takes a fraction of the time to complete.
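With batched=True, Dataset.map hands the function a whole batch at once, as a dictionary mapping each column name to a list of values (1,000 rows per batch by default), and expects a dictionary of lists back; those lists become new columns. The pure-Python sketch below mimics that contract with a hypothetical batched_map helper and a stand-in tokenizer, purely to show the data shapes involved. It is not how the datasets library is implemented.

```python
def batched_map(rows, fn, batch_size=1000):
    """Hypothetical stand-in for Dataset.map(fn, batched=True) on a list of dicts."""
    out = []
    for start in range(0, len(rows), batch_size):
        chunk = rows[start:start + batch_size]
        # Columnar view of the chunk: {"text": [...], "label": [...]}
        batch = {key: [row[key] for row in chunk] for key in chunk[0]}
        new_columns = fn(batch)  # fn returns {column_name: [one value per row]}
        for i, row in enumerate(chunk):
            merged = dict(row)
            merged.update({key: values[i] for key, values in new_columns.items()})
            out.append(merged)
    return out

def fake_tokenize(batch):
    # Stand-in for tokenizer(batch["text"], ...): counts words instead.
    return {"n_words": [len(text.split()) for text in batch["text"]]}

rows = [
    {"text": "i didnt feel humiliated", "label": 0},
    {"text": "i am feeling grouchy", "label": 3},
    {"text": "im grabbing a minute to post i feel greedy wrong", "label": 3},
]
mapped = batched_map(rows, fake_tokenize, batch_size=2)
print([row["n_words"] for row in mapped])  # [4, 4, 10]
```

Handing the tokenizer a list of texts per call, instead of one text at a time, is a large part of why batched mapping is faster than a plain loop.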

Bonus question: why wouldn’t we want to just define our function as lambda x: tokenizer(x["text"]) and save a couple lines of code? (See answer at the end of the tutorial).

Training on all of emotion takes a while, and I don't know about you, but I'm trying not to beat the shit out of my already overworked computer. So let's shuffle the data and grab a subset of the examples.

small_train_dataset = tokenized_data["train"].shuffle(seed=42).select(range(1000))
small_val_dataset = tokenized_data["validation"].shuffle(seed=42).select(range(100))

Now for the fun part - let's build this model! First, we load BERT from a pretrained checkpoint on the HuggingFace Hub and specify how many labels the classifier will predict.

model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=6  # one output per emotion class
)

3. Training & Evaluating

Second, we load the training arguments for the model (you can think of these as the training config). TrainingArguments has a ton of parameters, so you can check those out here. For our purposes, we only need to specify the output directory, the evaluation strategy (how often to evaluate on the validation set), and the number of epochs to run.

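training_args = TrainingArguments(
    output_dir="bert_trainer",    # where checkpoints and logs are written
    evaluation_strategy="epoch",  # evaluate after every epoch (newer releases call this eval_strategy)
    num_train_epochs=5,
)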

Finally, we're ready to train. We instantiate a Trainer and give it our model, the training arguments, the training dataset, and the validation dataset to evaluate on. Calling the trainer.train() method (not surprisingly) kicks off the model fine-tuning.

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_val_dataset,
)

trainer.train()

The first output you'll see (it is actually printed when the model is loaded with from_pretrained) will probably look something like…

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

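But don't freak out! This is what we expect: the pre-training heads (the cls.* weights) aren't needed for classification, so they are dropped, and the new classification head (classifier.weight and classifier.bias) is randomly initialized and will be learned during fine-tuning.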

If you have access to a GPU, the Trainer will automatically detect it and run most of the computation there. I was able to run the entire dataset on a single Nvidia GeForce GTX 1080 GPU in 77 minutes with an evaluation micro F1 score of 94%.

If you’re unfamiliar with the F1-scoring metric, you can read more about it here and why it can be a better metric than accuracy.
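Note that the Trainer set up above reports only the evaluation loss; to get an F1 score, you pass it a compute_metrics function. The sketch below is one way to write it with scikit-learn's f1_score; the metric names and averaging choices are my assumptions, not necessarily what produced the numbers in this post. The Trainer calls it with the model's logits and the true labels, and you hook it up with Trainer(..., compute_metrics=compute_metrics). One thing worth knowing: for single-label multiclass classification like this, micro-averaged F1 is numerically identical to accuracy, so macro-averaged F1, which weights every emotion equally, is the more informative number when classes are imbalanced.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

def compute_metrics(eval_pred):
    """Metric hook for transformers.Trainer: eval_pred unpacks to (logits, labels)."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)  # highest-scoring class per example
    return {
        "accuracy": accuracy_score(labels, predictions),
        "f1_micro": f1_score(labels, predictions, average="micro"),
        "f1_macro": f1_score(labels, predictions, average="macro"),
    }

# Toy check with made-up logits for 4 examples and 3 classes.
logits = np.array([[2.0, 0.1, 0.3],
                   [0.2, 1.5, 0.1],
                   [0.3, 0.2, 0.9],
                   [0.4, 1.1, 0.2]])
labels = np.array([0, 1, 2, 0])
metrics = compute_metrics((logits, labels))
print(metrics)
```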

I recognize that not everyone has access to such compute power, so for comparison, I ran the fine-tuning on an Intel(R) Xeon(R) CPU E5-2630 v4 @ 2.20GHz. Unfortunately, the amount of time this would take is somewhat absurd, so I scaled back the size of our dataset. Training on 1000 examples, I fine-tuned BERT in 1020 minutes with an evaluation micro F1 score of 86%.

Conclusion

In this tutorial, we learned about the incredible Transformer model called BERT and how to quickly and easily fine-tune it on a downstream task. With this knowledge, you can go forth and build many an NLP application.

Bonus question answer: the pickle module, Python's default serializer, does not serialize a function's code. It stores only a reference to the function (its module and qualified name) and looks that name up again when unpickling. A lambda is named "<lambda>", so it can't be found that way and can't be pickled. Therefore, if you want to save your preprocessing and use it again, you cannot rely on an anonymous function.
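A quick standard-library illustration of the pickle behavior described above. The can_pickle helper is hypothetical; whether a given library actually pickles your function depends on how you call it (multiprocessing, for instance, pickles the function it sends to worker processes).

```python
import json
import pickle

def can_pickle(obj) -> bool:
    """Return True if the standard pickle module can serialize obj."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False

# A function defined in an importable module pickles fine, because pickle only
# records where to find it (module "json", name "dumps"), not its code.
print(can_pickle(json.dumps))                # True
print(b"dumps" in pickle.dumps(json.dumps))  # True: only the name is stored

# A lambda's name is "<lambda>", which can't be looked up again, so it fails.
print(can_pickle(lambda x: x["text"]))       # False
```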

You can find all the code for this tutorial on my Github. If you have any comments, questions, or corrections, feel free to reach out.

Thanks to:

Dan Knight for his feedback and encouragement.