🧼 Clean labels using your model loss

TL;DR

  1. We introduce a simple technique for error analysis: using the model loss to find potential errors in the training data.

  2. We demonstrate the technique with a text classifier from the Hugging Face Hub that was fine-tuned on the AG News dataset.

  3. Using Rubrix, we verify more than 100 mislabelled examples in the training set of this well-known NLP benchmark.

  4. This trick is useful during model training with small and noisy datasets.

  5. This trick is complementary to other “data-centric” ML methods such as cleanlab (see the Rubrix tutorial on cleanlab).

Introduction

This tutorial explains a simple trick for finding potential errors in training data: using your model loss to identify label errors or ambiguous examples. This trick is inspired by the following tweet:

The technique is really simple: train your model on the training set, then apply it to that same training set and compute the loss for each example. If you sort the examples by loss, the ones with the highest loss are the most ambiguous and difficult for the model to learn.
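To make the ranking step concrete, here is a minimal, framework-free sketch. It assumes you already have, for every training example, the model's raw logits and the gold label index; the helpers cross_entropy_per_example and rank_by_loss are hypothetical names used only for illustration. The per-example value is the same quantity that torch.nn.functional.cross_entropy computes from logits with reduction="none".

```python
import math
from typing import List, Sequence


def cross_entropy_per_example(logits: Sequence[float], label: int) -> float:
    """Cross-entropy of one example: -log(softmax(logits)[label])."""
    # Subtract the max logit before exponentiating for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def rank_by_loss(all_logits: Sequence[Sequence[float]], labels: Sequence[int]) -> List[int]:
    """Return example indices sorted by descending loss (most suspicious first)."""
    losses = [cross_entropy_per_example(z, y) for z, y in zip(all_logits, labels)]
    return sorted(range(len(losses)), key=lambda i: losses[i], reverse=True)


# Toy example: 3 examples, 4 classes (made-up logits, not real model outputs).
logits = [
    [4.0, 0.0, 0.0, 0.0],  # confident in class 0, gold label 0 -> low loss
    [4.0, 0.0, 0.0, 0.0],  # confident in class 0, gold label 1 -> high loss
    [1.0, 1.0, 1.0, 1.0],  # uncertain -> medium loss
]
labels = [0, 1, 2]
print(rank_by_loss(logits, labels))  # [1, 2, 0]
```

The example the model confidently assigns to a different class than its gold label comes first, which is exactly the kind of example worth checking for a label error.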

This very simple technique can be used for error analysis during model development (e.g., identifying tokenization problems), but it turns out it is also a really simple way to clean up your training data, either during model development or after training data collection activities.

In this tutorial, we’ll use this technique with a well-known text classification benchmark, the AG News dataset. After computing the losses, we’ll use Rubrix to analyse the highest-loss examples. In less than 10 minutes, we manually checked and relabelled the first 100 examples: all of the 100 highest-loss examples turned out to be mislabelled in the original training set. When we visually inspect further examples, we still find label errors within the top 500.

Ingredients

  • A model fine-tuned with the AG News dataset (you could train your own model if you wish).

  • The AG News train split (the same trick could and should be applied to validation and test splits).

  • Rubrix for logging, exploring, and relabeling wrong examples.

Steps

  1. Load the fine-tuned model and the AG News train split.

  2. Compute the loss for each example and sort examples by descending loss.

  3. Log the first 500 examples into a Rubrix dataset. We provide the processed dataset if you want to skip the first two steps.

  4. Use the Rubrix web app to inspect the examples ordered by loss. In the following video, we show the full process of manually correcting the first 100 examples (all of them mislabelled in the original dataset; the original video is 8 minutes long):

Why it’s important

  1. Machine learning models are only as good as the data they’re trained on. Almost every source of training data can be considered “noisy” (e.g., crowd workers, annotator errors, weak supervision sources, data augmentation).

  2. With this simple technique, we’re able to find more than 100 label errors in a widely used benchmark in less than 10 minutes. Your dataset will probably be noisier.

  3. With advanced model architectures widely available, managing, cleaning, and curating data is becoming a key step in building robust ML applications. A good summary of the current situation can be found on the website of the Data-centric AI NeurIPS Workshop.

  4. This simple trick can be used across the whole ML life cycle, not only for finding label errors. With it, you can also improve data preprocessing, tokenization, and even your model architecture.

Setup Rubrix

Rubrix is a free and open-source tool to explore, annotate, and monitor data for NLP projects.

If you are new to Rubrix, check out the ⭐ GitHub repository.

If you have not installed and launched Rubrix, check the Setup and Installation guide.

Once installed, you only need to import Rubrix:

[3]:
import rubrix as rb

Tutorial dependencies

We’ll install the Hugging Face libraries transformers and datasets, as well as PyTorch, for the model and dataset we’ll use in the next steps.

[ ]:
!pip install transformers datasets torch

1. Load the fine-tuned model and the training dataset

[ ]:
import torch

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers.data.data_collator import DataCollatorWithPadding

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
[ ]:
# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("andi611/distilbert-base-uncased-agnews")
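model = AutoModelForSequenceClassification.from_pretrained("andi611/distilbert-base-uncased-agnews")
# move the model to the same device as the inputs and switch to inference mode
model = model.to(device).eval()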

# load the training split
from datasets import load_dataset
ds = load_dataset('ag_news', split='train')
[ ]:
# tokenize and encode the training set
def tokenize_and_encode(batch):
    return tokenizer(batch['text'], truncation=True)

ds_enc = ds.map(tokenize_and_encode, batched=True)

2. Computing the loss

The following code will compute the loss for each example using our trained model. This process is taken from the very well-explained blog post by Lewis Tunstall: “Using data collators for training and error analysis”, where he explains this process for error analysis during model training.

In our case, we instantiate a data collator ourselves, whereas he uses the data collator provided by the Trainer.
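DataCollatorWithPadding pads each batch to the length of its longest sequence and builds the matching attention mask, which is why the tokenization function above only truncates. It also renames the label column to labels, which is why loss_per_example reads batch["labels"]. The pure-Python sketch below only illustrates the dynamic padding idea; pad_batch is a hypothetical helper, not the transformers implementation, and it assumes right padding with pad id 0 (read the real value from tokenizer.pad_token_id).

```python
from typing import Dict, List


def pad_batch(input_ids: List[List[int]], pad_token_id: int = 0) -> Dict[str, List[List[int]]]:
    """Right-pad a batch of token id lists to the longest sequence in the batch.

    Hypothetical helper illustrating dynamic padding; pad_token_id=0 is an assumption.
    """
    max_len = max(len(ids) for ids in input_ids)
    padded, attention_mask = [], []
    for ids in input_ids:
        n_pad = max_len - len(ids)
        padded.append(ids + [pad_token_id] * n_pad)
        # 1 marks real tokens, 0 marks padding the model should ignore.
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": padded, "attention_mask": attention_mask}


batch = pad_batch([[101, 7592, 102], [101, 102]])
print(batch["input_ids"])       # [[101, 7592, 102], [101, 102, 0]]
print(batch["attention_mask"])  # [[1, 1, 1], [1, 1, 0]]
```

Padding per batch rather than to a fixed maximum length keeps short batches cheap, which matters when running inference over the whole training split.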

[ ]:
# create the data collator for inference
data_collator = DataCollatorWithPadding(tokenizer, padding=True)
[ ]:
# function to compute the loss example-wise
def loss_per_example(batch):
    # pad the batch; the collator returns PyTorch tensors and renames "label" to "labels"
    batch = data_collator(batch)
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    labels = batch["labels"].to(device)

    with torch.no_grad():
        output = model(input_ids=input_ids, attention_mask=attention_mask)
        batch["predicted_label"] = torch.argmax(output.logits, dim=1)
        # compute the probabilities for logging them into Rubrix,
        # normalising over the classes of each example (dim=1)
        batch["predicted_probas"] = torch.nn.functional.softmax(output.logits, dim=1)
        # don't reduce the loss (return the loss for each example)
        batch["loss"] = torch.nn.functional.cross_entropy(output.logits, labels, reduction="none")

    # datasets complains about tensors and numpy dtypes, let's use Python lists
    for k, v in batch.items():
        batch[k] = v.cpu().numpy().tolist()

    return batch
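The probabilities logged into Rubrix should be normalised over the classes of each example, that is along dimension 1 of the (batch_size, num_labels) logits tensor; normalising along dimension 0 would instead mix different examples of the same batch. The pure-Python sketch below, using a hypothetical softmax helper and made-up logits, shows that only the row-wise version yields one probability distribution per example.

```python
import math
from typing import List


def softmax(values: List[float]) -> List[float]:
    """Numerically stable softmax over a single list of scores."""
    m = max(values)
    exps = [math.exp(v - m) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up logits for a batch of 2 examples and 4 classes: shape (batch_size, num_labels).
logits = [[2.0, 0.5, 0.1, -1.0],
          [0.3, 3.0, 0.2, 0.0]]

# Row-wise (one example's classes), the equivalent of softmax(..., dim=1).
per_example = [softmax(row) for row in logits]

# Column-wise (one class across the batch), the equivalent of softmax(..., dim=0).
columns = [softmax(list(col)) for col in zip(*logits)]
per_class = [list(row) for row in zip(*columns)]  # transpose back to (batch, classes)

print([round(sum(p), 6) for p in per_example])  # [1.0, 1.0]
print([round(sum(p), 6) for p in per_class])    # rows no longer sum to 1
```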
[ ]:
import pandas as pd

losses_ds = ds_enc.remove_columns("text").map(loss_per_example, batched=True, batch_size=32)

# turn the dataset into a Pandas dataframe, sort by descending loss and visualize the top examples.
pd.set_option("display.max_colwidth", None)

losses_ds.set_format('pandas')
losses_df = losses_ds[:][['label', 'predicted_label', 'loss', 'predicted_probas']]

# add back the text column we removed before computing the losses
losses_df['text'] = ds_enc['text']
losses_df.sort_values("loss", ascending=False).head(10)
label predicted_label loss predicted_probas text
44984 1 0 8.833023 [0.06412869691848755, 7.090532017173246e-05, 0.00019675122166518122, 0.0002370826987316832] Baghdad blasts kills at least 16 Insurgents have detonated two bombs near a convoy of US military vehicles in southern Baghdad, killing at least 16 people, Iraqi police say.
101562 1 0 8.781285 [0.12395327538251877, 9.289286026614718e-06, 0.0001785584754543379, 0.0007945793331600726] Immoral, unjust, oppressive dictatorship. . . and then there #39;s <b>...</b> ROBERT MUGABES Government is pushing through legislation designed to prevent human rights organisations from operating in Zimbabwe.
31564 1 2 8.772168 [0.00016983140085358173, 8.863882612786256e-06, 0.18702593445777893, 0.00025946463574655354] Ford to Cut 1,150 Jobs At British Jaguar Unit Ford Motor Co. announced Friday that it would eliminate 1,150 jobs in England to streamline its Jaguar Cars Ltd. unit, where weak sales have failed to offset spending on new products and other parts of the business.
41247 1 0 8.751480 [0.2929899990558624, 7.849136454751715e-05, 0.00034211069578304887, 4.463219011086039e-05] Palestinian gunmen kidnap CNN producer GAZA CITY, Gaza Strip -- Palestinian gunmen abducted a CNN producer in Gaza City on Monday, the network said. The network said Riyadh Ali was taken away at gunpoint from a CNN van.
44961 1 0 8.740394 [0.06420651078224182, 7.788064249325544e-05, 0.0001824614155339077, 0.0002348265261389315] Bomb Blasts in Baghdad Kill at Least 35, Wound 120 Insurgents detonated three car bombs near a US military convoy in southern Baghdad on Thursday, killing at least 35 people and wounding around 120, many of them children, officials and doctors said.
75216 1 0 8.735966 [0.13383473455905914, 1.837693343986757e-05, 0.00017987379396799952, 0.00036031895433552563] Marine Wives Rally A group of Marine wives are running for the family of a Marine Corps officer who was killed in Iraq.
31229 1 2 8.729340 [5.088283069198951e-05, 2.4471093638567254e-05, 0.18256260454654694, 0.00033902408904396] Auto Stocks Fall Despite Ford Outlook Despite a strong profit outlook from Ford Motor Co., shares of automotive stocks moved mostly lower Friday on concerns sales for the industry might not be as strong as previously expected.
19737 3 1 8.545797 [4.129256194573827e-05, 0.1872873306274414, 4.638762402464636e-05, 0.00010757221753010526] Mladin Release From Road Atlanta Australia #39;s Mat Mladin completed a winning double at the penultimate round of this year #39;s American AMA Chevrolet Superbike Championship after taking
60726 2 0 8.437369 [0.5235446095466614, 4.4463453377829865e-05, 3.5171411582268775e-05, 8.480428368784487e-05] Suicide Bombings Kill 10 in Green Zone Insurgents hand-carried explosives into the most fortified section of Baghdad Thursday and detonated them within seconds of each other, killing 10 people and wounding 20.
28307 3 1 8.386065 [0.00018589739920571446, 0.42903241515159607, 2.5073826691368595e-05, 3.97983385482803e-05] Lightning Strike Injures 40 on Texas Field (AP) AP - About 40 players and coaches with the Grapeland High School football team in East Texas were injured, two of them critically, when lightning struck near their practice field Tuesday evening, authorities said.
[2]:
# save this to a file for further analysis
#losses_df.to_json("agnews_train_loss.jsonl", orient="records", lines=True)

Pandas and Jupyter notebooks are useful for initial inspection and programmatic analysis. However, if you want to quickly explore the examples, relabel them, and share them with other project members, Rubrix provides a straightforward way to do this. Let’s see how.
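Before moving to Rubrix, a quick programmatic summary can already hint at systematic confusions. The sketch below counts (gold label, predicted label) pairs among the k highest-loss rows. It assumes the rows are plain dicts with label, predicted_label and loss keys (for example, obtained with losses_df.to_dict("records")); confusion_among_top_losses is a hypothetical helper and the rows are toy data.

```python
from collections import Counter
from typing import Dict, List


def confusion_among_top_losses(rows: List[Dict], k: int) -> Counter:
    """Count (gold label, predicted label) pairs among the k highest-loss rows."""
    top = sorted(rows, key=lambda r: r["loss"], reverse=True)[:k]
    return Counter((r["label"], r["predicted_label"]) for r in top)


# Toy rows; AG News label ids: 0=World, 1=Sports, 2=Business, 3=Sci/Tech.
rows = [
    {"label": 1, "predicted_label": 0, "loss": 8.8},
    {"label": 1, "predicted_label": 0, "loss": 8.7},
    {"label": 1, "predicted_label": 2, "loss": 8.6},
    {"label": 0, "predicted_label": 0, "loss": 0.1},
]
print(confusion_among_top_losses(rows, k=3).most_common())
# [((1, 0), 2), ((1, 2), 1)]
```

A pair that dominates the top of the ranking (for instance, gold Sports predicted as World) is a good first filter to apply once the records are in Rubrix.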

3. Log high loss examples into Rubrix

Using the amazing Hugging Face Hub, we’ve shared the resulting dataset, which you can find here.

[7]:
# if you have skipped the first two steps you can load the dataset here:
#losses_df = pd.read_json("agnews_train_loss.jsonl", lines=True, orient="records")
[ ]:
# creates a Text classification record for logging into Rubrix
def make_record(row):

    return rb.TextClassificationRecord(
        inputs={"text": row.text},
        # this is the "gold" label in the original dataset (single-label task, so a plain string)
        annotation=ds_enc.features['label'].names[row.label],
        # this is the prediction together with its probability
        prediction=[(ds_enc.features['label'].names[row.predicted_label], row.predicted_probas[row.predicted_label])],
        # metadata fields can be used for sorting and filtering, here we log the loss
        metadata={"loss": row.loss},
        # who makes the prediction
        prediction_agent="andi611/distilbert-base-uncased-agnews",
        # source of the gold label
        annotation_agent="ag_news_benchmark"
    )
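make_record logs only the top predicted label and its probability. If you would rather log a score for every class (Rubrix's prediction field takes a list of (label, score) tuples), a helper along these lines could build that list. all_class_prediction is hypothetical; the label names are those of the AG News label feature, which in the tutorial you would read from ds_enc.features['label'].names, and the probabilities are assumed to be normalised per example.

```python
from typing import List, Sequence, Tuple

# Label names of the AG News "label" feature, in id order.
AG_NEWS_LABELS = ["World", "Sports", "Business", "Sci/Tech"]


def all_class_prediction(
    probas: Sequence[float], names: Sequence[str] = AG_NEWS_LABELS
) -> List[Tuple[str, float]]:
    """Pair every class name with its probability, highest score first."""
    if len(probas) != len(names):
        raise ValueError("expected one probability per class")
    return sorted(zip(names, probas), key=lambda pair: pair[1], reverse=True)


print(all_class_prediction([0.05, 0.80, 0.10, 0.05]))
# [('Sports', 0.8), ('Business', 0.1), ('World', 0.05), ('Sci/Tech', 0.05)]
```

Inside make_record, the prediction argument would then become all_class_prediction(row.predicted_probas, ds_enc.features['label'].names).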
[ ]:
# keep the 500 highest-loss examples; to log the full dataset, drop .head(500)
top_losses = losses_df.sort_values("loss", ascending=False).head(500)

# build Rubrix records
records = top_losses.apply(make_record, axis=1)
[ ]:
rb.log(records, name="ag_news_error_analysis")
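When selecting the top examples, keep in mind that Python slices exclude the end index, so taking the top 500 rows needs [:500] (or .head(500)). If the loss table is too large to sort comfortably, heapq.nlargest from the standard library selects the top k without a full sort. A minimal sketch over plain dict rows, with top_k_by_loss as a hypothetical helper and toy losses:

```python
import heapq
from typing import Dict, List


def top_k_by_loss(rows: List[Dict], k: int) -> List[Dict]:
    """Return the k rows with the highest loss, sorted by descending loss."""
    return heapq.nlargest(k, rows, key=lambda r: r["loss"])


rows = [{"id": i, "loss": loss} for i, loss in enumerate([0.2, 8.8, 3.1, 8.7, 0.05])]
print([r["id"] for r in top_k_by_loss(rows, k=3)])  # [1, 3, 2]

# Slices exclude the end index: [0:499] keeps 499 items, [:500] keeps 500.
ids = list(range(1000))
print(len(ids[0:499]), len(ids[:500]))  # 499 500
```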

4. Using Rubrix Webapp for inspection and relabeling

In this step, we have a Rubrix Dataset available for exploration and annotation. A useful feature for this use case is Sorting. With Rubrix you can sort your examples by combining different fields, both from the standard fields (such as score) and custom fields (via the metadata fields). In this case, we’ve logged the loss so we can order our training examples by loss in descending order (showing higher loss examples first).

While preparing this tutorial, we manually checked and relabelled the first 100 examples. You can watch the full session (sped up during the last part) here. In the video, we use the Rubrix annotation mode to change the label of mislabelled examples (the first label corresponds to the original “gold” label and the second corresponds to the model’s prediction).

We’ve also spot-checked the next 400 examples at random and found many potential errors. If you are interested, you can repeat our experiment or even help validate the next 100 examples; we’d love to hear about your results! We plan to share the 100 relabelled examples with the community on the Hugging Face Hub.

Next steps

If you are interested in the topic of training data curation and denoising datasets, check out the tutorial for using Rubrix with cleanlab.

🙋‍♀️ Join the Rubrix community! A good place to start is the discussion forum.

⭐ Rubrix GitHub repo to stay updated.