🧼 Clean labels using your model loss#

In this tutorial, we introduce a simple technique for error analysis: using the model loss to find potential training data errors.

  • 🤗 This technique is shown using a fine-tuned text classifier from the Hugging Face Hub on the AG News dataset.

  • ✅ Using Rubrix, we will verify more than 50 mislabelled examples in the training set of this well-known NLP benchmark.

  • 💥 This trick is useful for model training with small and noisy datasets.

  • 👥 This trick is complementary to other “data-centric” ML methods such as cleanlab (see this Rubrix tutorial).

Introduction#

This tutorial explains a simple trick you can leverage with Rubrix for finding potential errors in training data: using your model loss to identify label errors or ambiguous examples. This trick is not new (those who’ve worked with fastai know how useful the plot_top_losses method is). Even Andrej Karpathy tweeted about this some time ago:

The technique is really simple: train your model on the training set, then apply the trained model to that same training set and compute the loss for each example. If you sort the examples by loss, those with the highest loss are the ones the model found most ambiguous or difficult to learn, which makes them good candidates for label errors.
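
As a rough illustration of the idea, independent of any particular model, the sketch below computes a per-example cross-entropy loss from predicted class probabilities and ranks the examples by it. The probabilities and labels are made-up toy values, and `per_example_loss` is a hypothetical helper introduced here, not part of Rubrix or transformers.

```python
import math

def per_example_loss(probas, label):
    """Cross-entropy for one example: -log of the probability given to the gold label."""
    return -math.log(probas[label])

# toy data (made-up values): predicted class probabilities and gold label
examples = [
    ([0.90, 0.05, 0.03, 0.02], 0),  # confident and agrees with the gold label
    ([0.10, 0.70, 0.10, 0.10], 1),  # fairly confident and agrees with the gold label
    ([0.97, 0.01, 0.01, 0.01], 1),  # confident, but disagrees with the gold label
]

losses = [per_example_loss(p, y) for p, y in examples]

# indices of the examples, highest loss first
ranked = sorted(range(len(examples)), key=lambda i: losses[i], reverse=True)
print(ranked)
```

The example where a confident prediction disagrees with the gold label ends up first in the ranking, which is exactly the kind of record worth inspecting by hand.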

This technique can be used for error analysis during model development (e.g., identifying tokenization problems), but it turns out that it is also a really simple way of cleaning up your training data, during model development or after training data collection activities.

In this tutorial, we’ll use this technique with a well-known text classification benchmark, the AG News dataset. After computing the losses, we’ll use Rubrix to analyse the highest loss examples. In less than 5 minutes, we manually checked and relabelled the first 50 examples. In fact, the 50 examples with the highest loss are all incorrectly labelled in the original training set. If we visually inspect further examples, we still find label errors in the top 500 examples.

Why it’s important#

  1. Machine learning models are only as good as the data they’re trained on. Almost all training data sources can be considered “noisy” (e.g., crowd-workers, annotator errors, weak supervision sources, data augmentation, etc.).

  2. With this simple technique we’re able to find more than 50 label errors in a widely used benchmark in less than 5 minutes (your dataset will probably be noisier!).

  3. With advanced model architectures widely available, managing, cleaning, and curating data is becoming a key step for making robust ML applications. A good summary of the current situation can be found on the website of the Data-centric AI NeurIPS Workshop.

  4. This simple trick can be used across the whole ML lifecycle and not only for finding label errors. With this trick you can improve data preprocessing, tokenization, and even your model architecture.

Setup#

Rubrix is a free and open-source tool to explore, annotate, and monitor data for NLP projects.

If you are new to Rubrix, check out the ⭐ GitHub repository.

If you have not installed and launched Rubrix yet, check the Setup and Installation guide.

For this tutorial we will also need the third-party libraries transformers and datasets, as well as PyTorch, which can be installed via pip:

[ ]:
%pip install transformers datasets torch -qqq

Preliminaries#

  1. A model fine-tuned with the AG News dataset (you could train your own model if you wish).

  2. The AG News train split (the same trick could and should be applied to validation and test splits).

  3. Rubrix for logging, exploring, and relabeling wrong examples (we provide a pre-computed dataset, so feel free to skip to this step).

1. Load the fine-tuned model and the training dataset#

Now, we will load the AG News dataset. But first, we need to define and set the device, the model and the tokenizer:

[ ]:
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
[ ]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("andi611/distilbert-base-uncased-ner-agnews")
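model = AutoModelForSequenceClassification.from_pretrained("andi611/distilbert-base-uncased-ner-agnews")
# move the model to the same device as the input tensors and make sure dropout is disabled
model = model.to(device).eval()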
[ ]:
from datasets import load_dataset

# load the training split
ds = load_dataset('ag_news', split='train')
[ ]:
# tokenize and encode the training set
def tokenize_and_encode(batch):
    return tokenizer(batch['text'], truncation=True)

ds_enc = ds.map(tokenize_and_encode, batched=True)

2. Computing the loss#

The following code will compute the loss for each example using our trained model. This process is taken from the very well-explained blog post by Lewis Tunstall: “Using data collators for training and error analysis”, where he explains this process for error analysis during model training.

In our case, we instantiate a data collator directly, while he uses the data collator from the Trainer:

[ ]:
from transformers.data.data_collator import DataCollatorWithPadding

# create the data collator for inference
data_collator = DataCollatorWithPadding(tokenizer, padding=True)
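
Conceptually, a padding collator brings all sequences in a batch to the same length and marks the padded positions in the attention mask. The following is only a plain-Python sketch of that idea, not the transformers implementation; `pad_batch` is a hypothetical helper, and the pad id of 0 and the token ids are assumptions made for illustration.

```python
def pad_batch(sequences, pad_id=0):
    """Pad token-id lists to the longest one and build the matching attention masks."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        # real tokens keep mask 1, padded positions get mask 0
        input_ids.append(seq + [pad_id] * n_pad)
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# made-up token ids for two texts of different length
batch = pad_batch([[101, 7592, 102], [101, 7592, 2088, 999, 102]])
print(batch["input_ids"])
print(batch["attention_mask"])
```

Padding per batch rather than to a global maximum length keeps the tensors small, which matters when running inference over the whole training set.
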
[ ]:
# function to compute the loss example-wise
def loss_per_example(batch):
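    # pad the batch (the collator also renames the "label" column to "labels")
    batch = data_collator(batch)
    # as_tensor accepts both tensors and lists and moves them to the target device
    input_ids = torch.as_tensor(batch["input_ids"], device=device)
    attention_mask = torch.as_tensor(batch["attention_mask"], device=device)
    labels = torch.as_tensor(batch["labels"], device=device)

    with torch.no_grad():
        output = model(input_ids, attention_mask)
        batch["predicted_label"] = torch.argmax(output.logits, dim=1)
        # compute the probabilities for logging them into Rubrix
        # (softmax over the class dimension, so each row sums to 1; dim=0 would normalize across the batch)
        batch["predicted_probas"] = torch.nn.functional.softmax(output.logits, dim=-1)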

    # don't reduce the loss (return the loss for each example)
    loss = torch.nn.functional.cross_entropy(output.logits, labels, reduction="none")
    batch["loss"] = loss

    # datasets complains with numpy dtypes, let's use Python lists
    for k, v in batch.items():
        batch[k] = v.cpu().numpy().tolist()

    return batch

Now, it is time to turn the dataset into a Pandas dataframe and sort this dataset by descending loss:

[ ]:
import pandas as pd

losses_ds = ds_enc.remove_columns("text").map(loss_per_example, batched=True, batch_size=32)

# turn the dataset into a Pandas dataframe, sort by descending loss and visualize the top examples.
pd.set_option("display.max_colwidth", None)

losses_ds.set_format('pandas')
losses_df = losses_ds[:][['label', 'predicted_label', 'loss', 'predicted_probas']]

# add back the text column that was removed before computing the losses
losses_df['text'] = ds_enc['text']
losses_df.sort_values("loss", ascending=False).head(10)
label predicted_label loss predicted_probas text
44984 1 0 8.833023 [0.06412869691848755, 7.090532017173246e-05, 0.00019675122166518122, 0.0002370826987316832] Baghdad blasts kills at least 16 Insurgents have detonated two bombs near a convoy of US military vehicles in southern Baghdad, killing at least 16 people, Iraqi police say.
101562 1 0 8.781285 [0.12395327538251877, 9.289286026614718e-06, 0.0001785584754543379, 0.0007945793331600726] Immoral, unjust, oppressive dictatorship. . . and then there #39;s <b>...</b> ROBERT MUGABES Government is pushing through legislation designed to prevent human rights organisations from operating in Zimbabwe.
31564 1 2 8.772168 [0.00016983140085358173, 8.863882612786256e-06, 0.18702593445777893, 0.00025946463574655354] Ford to Cut 1,150 Jobs At British Jaguar Unit Ford Motor Co. announced Friday that it would eliminate 1,150 jobs in England to streamline its Jaguar Cars Ltd. unit, where weak sales have failed to offset spending on new products and other parts of the business.
41247 1 0 8.751480 [0.2929899990558624, 7.849136454751715e-05, 0.00034211069578304887, 4.463219011086039e-05] Palestinian gunmen kidnap CNN producer GAZA CITY, Gaza Strip -- Palestinian gunmen abducted a CNN producer in Gaza City on Monday, the network said. The network said Riyadh Ali was taken away at gunpoint from a CNN van.
44961 1 0 8.740394 [0.06420651078224182, 7.788064249325544e-05, 0.0001824614155339077, 0.0002348265261389315] Bomb Blasts in Baghdad Kill at Least 35, Wound 120 Insurgents detonated three car bombs near a US military convoy in southern Baghdad on Thursday, killing at least 35 people and wounding around 120, many of them children, officials and doctors said.
75216 1 0 8.735966 [0.13383473455905914, 1.837693343986757e-05, 0.00017987379396799952, 0.00036031895433552563] Marine Wives Rally A group of Marine wives are running for the family of a Marine Corps officer who was killed in Iraq.
31229 1 2 8.729340 [5.088283069198951e-05, 2.4471093638567254e-05, 0.18256260454654694, 0.00033902408904396] Auto Stocks Fall Despite Ford Outlook Despite a strong profit outlook from Ford Motor Co., shares of automotive stocks moved mostly lower Friday on concerns sales for the industry might not be as strong as previously expected.
19737 3 1 8.545797 [4.129256194573827e-05, 0.1872873306274414, 4.638762402464636e-05, 0.00010757221753010526] Mladin Release From Road Atlanta Australia #39;s Mat Mladin completed a winning double at the penultimate round of this year #39;s American AMA Chevrolet Superbike Championship after taking
60726 2 0 8.437369 [0.5235446095466614, 4.4463453377829865e-05, 3.5171411582268775e-05, 8.480428368784487e-05] Suicide Bombings Kill 10 in Green Zone Insurgents hand-carried explosives into the most fortified section of Baghdad Thursday and detonated them within seconds of each other, killing 10 people and wounding 20.
28307 3 1 8.386065 [0.00018589739920571446, 0.42903241515159607, 2.5073826691368595e-05, 3.97983385482803e-05] Lightning Strike Injures 40 on Texas Field (AP) AP - About 40 players and coaches with the Grapeland High School football team in East Texas were injured, two of them critically, when lightning struck near their practice field Tuesday evening, authorities said.
[2]:
# save this to a file for further analysis
#losses_df.to_json("agnews_train_loss.json", orient="records", lines=True)
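
To get a feel for these loss values, recall that the cross-entropy loss of a single example is the negative log of the probability the model assigns to the gold label. The sketch below simply inverts that relation for the highest loss in the table above; `gold_label_probability` is a hypothetical helper introduced here for illustration.

```python
import math

def gold_label_probability(loss):
    """Invert the cross-entropy: loss = -log(p)  =>  p = exp(-loss)."""
    return math.exp(-loss)

# highest loss in the table above
p = gold_label_probability(8.833023)
print(f"{p:.6f}")

# for comparison: the loss of a uniform guess over the 4 AG News classes
uniform_loss = -math.log(1 / 4)
print(f"{uniform_loss:.3f}")
```

A loss near 8.8 therefore means the model gave the original label a probability of roughly 0.015%, far below the 25% of a uniform guess. The model is not merely unsure about these examples, it confidently disagrees with their labels.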

While Pandas and Jupyter notebooks are useful for initial inspection and programmatic analysis, Rubrix provides a straightforward way to quickly explore the examples, relabel them, and share them with other project members. Let’s see how.

3. Log high loss examples into Rubrix#

Using the amazing Hugging Face Hub, we’ve shared the resulting dataset, which you can find here and load directly using the datasets library.

Now, we log the first 500 examples into a Rubrix dataset:

[ ]:
# if you have skipped the first two steps you can load the dataset here:
import pandas as pd
from datasets import load_dataset

dataset = load_dataset("dvilasuero/ag_news_training_set_losses", split='train')
losses_df = dataset.to_pandas()

ds = load_dataset('ag_news', split='test') # only for getting the label names
[7]:
import rubrix as rb
# creates a Text classification record for logging into Rubrix
def make_record(row):

    return rb.TextClassificationRecord(
        text=row.text,
        # this is the "gold" label in the original dataset
        annotation=[(ds.features['label'].names[row.label])],
        # this is the prediction together with its probability
        prediction=[(ds.features['label'].names[row.predicted_label], row.predicted_probas[row.predicted_label])],
        # metadata fields can be used for sorting and filtering, here we log the loss
        metadata={"loss": row.loss},
        # who makes the prediction
        prediction_agent="andi611/distilbert-base-uncased-ner-agnews",
        # source of the gold label
        annotation_agent="ag_news_benchmark"
    )
[8]:
# if you want to log the full dataset remove the indexing
top_losses = losses_df.sort_values("loss", ascending=False)[0:500]

# build Rubrix records
records = top_losses.apply(make_record, axis=1)
[ ]:
rb.log(records, name="ag_news_error_analysis")

4. Using Rubrix Webapp for inspection and relabeling#

In this step, we have a Rubrix Dataset available for exploration and annotation. A useful feature for this use case is Sorting. With Rubrix you can sort your examples by combining different fields, both from the standard fields (such as score) and custom fields (via the metadata fields). In this case, we’ve logged the loss so we can order our training examples by loss in descending order (showing higher loss examples first).

For preparing this tutorial, we have manually checked and relabelled the first 50 examples. Moreover, we’ve shared this re-annotated dataset on the Hugging Face Hub. In the next section, we show you how easy it is to share Rubrix Datasets on the Hub.

5. Sharing the dataset in the Hugging Face Hub#

Let’s first load the re-annotated examples. Re-labelled examples are marked as annotated_by the user rubrix, which is the default user when launching Rubrix with Docker. We can retrieve only these records using the query param as follows:

[11]:
import rubrix as rb
dataset = rb.load("ag_news_error_analysis", query="annotated_by:rubrix").to_pandas()

# let's do some transformations before uploading the dataset
dataset['loss'] = dataset.metadata.transform(lambda r: r['loss'])
dataset = dataset.rename(columns={"annotation": "corrected_label"})

dataset.head()
[11]:
inputs prediction corrected_label prediction_agent annotation_agent multi_label explanation id metadata status event_timestamp metrics text loss
0 {'text': 'Top nuclear official briefs Majlis c... [(World, 0.1832696944)] World andi611/distilbert-base-uncased-ner-agnews rubrix False None 071a1014-71e7-41f4-83e4-553ba47610cf {'loss': 7.6656146049} Validated None {} Top nuclear official briefs Majlis committee T... 7.665615
1 {'text': 'Fischer Delivers Strong Message in S... [(World, 0.0695228428)] World andi611/distilbert-base-uncased-ner-agnews rubrix False None 07c8c4f6-3288-46f4-a618-3da4a537e605 {'loss': 7.9892320633} Validated None {} Fischer Delivers Strong Message in Syria Germa... 7.989232
2 {'text': 'The Politics of Time and Dispossessi... [(Sci/Tech, 0.100481838)] Sci/Tech andi611/distilbert-base-uncased-ner-agnews rubrix False None 0965a0d1-4886-432a-826a-58e99dfd9972 {'loss': 7.133708477} Validated None {} The Politics of Time and Dispossession Make a ... 7.133708
3 {'text': 'Hadash Party joins prisoners #39; st... [(World, 0.1749624908)] World andi611/distilbert-base-uncased-ner-agnews rubrix False None 09fc7065-a2c8-4041-adf8-34e029a7fde0 {'loss': 7.339015007} Validated None {} Hadash Party joins prisoners #39; strike for 2... 7.339015
4 {'text': 'China May Join \$10Bln Sakhalin-2 Ru... [(Business, 0.1370282918)] Business andi611/distilbert-base-uncased-ner-agnews rubrix False None 1ef97c49-2f0f-43be-9b28-80a291cb3b1d {'loss': 7.321100235} Validated None {} China May Join \$10Bln Sakhalin-2 Russia said ... 7.321100
[12]:
# let's add the original dataset labels to share them together with the corrected ones
# we sort our corrected dataset by descending loss
dataset = dataset.sort_values("loss", ascending=False)

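# we add original labels in string form, using the AG News label names in id order
# (`ds` is the ag_news split loaded earlier; the order of `.unique()` is not guaranteed to match the label ids)
# note: this assumes the re-annotated records are exactly the 50 highest-loss examples
id2label = ds.features['label'].names
original_labels = [id2label[i] for i in top_losses[0:50].label.values]
dataset["original_label"] = original_labels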
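
Mapping integer label ids back to names only works if the list of names is in id order. The small sketch below shows how a list built from order of first appearance can silently swap classes. It assumes the AG News id order used by the Hugging Face `ag_news` dataset (World, Sports, Business, Sci/Tech); the appearance order and the ids are made-up values.

```python
# AG News label names in id order (assumed, as in the Hugging Face `ag_news` dataset)
AG_NEWS_NAMES = ["World", "Sports", "Business", "Sci/Tech"]

# a list built from order of first appearance, e.g. from `.unique()` (made-up order)
appearance_order = ["World", "Business", "Sports", "Sci/Tech"]

original_ids = [1, 2, 0, 3]  # made-up integer labels

by_id = [AG_NEWS_NAMES[i] for i in original_ids]
by_appearance = [appearance_order[i] for i in original_ids]

print(by_id)
print(by_appearance)
```

With the appearance-based list, Sports and Business are exchanged, so it is safer to take the names from the dataset's `ClassLabel` feature.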

Now let’s transform this into a Dataset and define the features schema:

[13]:
from datasets import Dataset, Features, Value, ClassLabel

ds = dataset[['text', 'corrected_label', 'original_label']].to_dict(orient='list')

hf_ds = Dataset.from_dict(
    ds,
    features=Features({
        "text": Value("string"),
        "corrected_label": ClassLabel(names=list(dataset.corrected_label.unique())),
        "original_label": ClassLabel(names=list(dataset.corrected_label.unique()))
    })
)
[19]:
hf_ds.features
[19]:
{'text': Value(dtype='string', id=None),
 'corrected_label': ClassLabel(num_classes=4, names=['World', 'Business', 'Sports', 'Sci/Tech'], names_file=None, id=None),
 'original_label': ClassLabel(num_classes=4, names=['World', 'Business', 'Sports', 'Sci/Tech'], names_file=None, id=None)}

Uploading the dataset with the push_to_hub method is as easy as:

[ ]:
hf_ds.push_to_hub("Recognai/ag_news_corrected_labels")

Now the dataset is publicly available at the Hub!

Summary#

In this tutorial we saw how you can leverage the model loss to find potential label errors in your training data set. The Rubrix web app makes it very convenient to sort your data by loss, inspect single records by eye, and correct label errors on the fly.

Next steps#

If you are interested in the topic of training data curation and denoising datasets, check out the tutorial for using Rubrix with cleanlab.

⭐ Star the Rubrix GitHub repo to stay updated.

📚 Rubrix documentation for more guides and tutorials.

🙋‍♀️ Join the Rubrix community! A good place to start is the discussion forum.
