📰 Building a news classifier with weak supervision

TL;DR

  1. We build a news classifier using rules and weak supervision

  2. For this example, we use the AG News dataset, but you can follow this process to programmatically label any dataset.

  3. The train split without labels is used to build a training set with rules, Rubrix and Snorkel’s Label model.

  4. The test set is used for evaluating our weak labels, label model and downstream news classifier.

  5. We achieve a 0.81 macro-averaged F1-score without using a single example from the original training set, and with a fairly lightweight model (scikit-learn’s MultinomialNB).

The following diagram shows the overall process for using weak supervision with Rubrix:

Labeling workflow

Setup Rubrix

Rubrix is a free and open-source tool to explore, annotate, and monitor data for NLP projects.

If you are new to Rubrix, check out the ⭐ Github repository.

You can install Rubrix on your local machine, on a server, or using a cloud provider. If you have not installed and launched Rubrix, check the Setup and Installation guide.
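For a quick local setup, the sketch below outlines the usual steps. Treat it only as an orientation: the exact commands and the Elasticsearch configuration depend on your environment and Rubrix version, and the Setup and Installation guide is the reference.

# Minimal local setup sketch (see the Setup and Installation guide for the exact steps):
# 1. Make sure an Elasticsearch instance is running; Rubrix uses it as its storage backend.
# 2. Install the Rubrix Python package:
#    !pip install rubrix
# 3. Launch the Rubrix server and webapp (command may vary with the Rubrix version):
#    !python -m rubrix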

Once installed, you only need to import Rubrix and some other libraries we’ll be using for this tutorial:

[2]:
import rubrix as rb
from rubrix.labeling.text_classification import *

from datasets import load_dataset
import pandas as pd

1. Load test and unlabelled datasets into Rubrix

Let’s load the test split from the ag_news dataset, which we’ll be using for testing our label and downstream models.

[ ]:
dataset = load_dataset("ag_news", split="test")

labels = dataset.features["label"].names

records = [
    rb.TextClassificationRecord(
        inputs=record["text"],
        metadata={"split": "test"},
        annotation=labels[record["label"]]
    )
    for record in dataset
]

rb.log(records, name="news")

Let’s load the train split from the ag_news dataset without labels. Our goal will be to programmatically build a training set using rules and weak supervision.

[ ]:
dataset = load_dataset("ag_news", split="train")

records = [
    rb.TextClassificationRecord(
        inputs=record["text"],
        metadata={"split": "unlabelled"},
    )
    for record in dataset
]

rb.log(records, name="news")

The result of the above is the following dataset in Rubrix, with 127,600 records (120,000 unlabelled and 7,600 for testing).

You can use the webapp for finding good rules for programmatic labeling.

News dataset
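Besides browsing in the webapp, you can also try a candidate query directly from Python by loading only the records it matches. The query shown here is hypothetical and uses the same Elasticsearch query string syntax as the rules defined below:

# Load only the records matching a candidate query (hypothetical example)
matched = rb.load("news", query="footbal*")

# Number of records the candidate rule would cover
print(len(matched))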

2. Create rules and weak labels

Let’s define some rules for each category. Here you can use the expressive power of Elasticsearch’s query string DSL.

[3]:
# Define queries and patterns for each category (using ES DSL)
queries = [
  (["money", "financ*", "dollar*"], "Business"),
  (["war", "gov*", "minister*", "conflict"], "World"),
  (["footbal*", "sport*", "game", "play*"], "Sports"),
  (["sci*", "techno*", "computer*", "software", "web"], "Sci/Tech")
]

rules = [
    Rule(query=term, label=label)
    for terms, label in queries
    for term in terms
]
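
The rules above only use single terms and wildcards, but the query string syntax also supports phrases and boolean operators. The rules below are purely illustrative and are not used in this tutorial:

# Hypothetical richer query string rules (not used in this tutorial)
extra_rules = [
    Rule(query='"stock market"', label="Business"),     # exact phrase
    Rule(query="apple AND iphone", label="Sci/Tech"),   # boolean operator
]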
[ ]:
weak_labels = WeakLabels(
    rules=rules,
    dataset="news"
)

Applying the rules and getting the weak labels for the 127,600 examples takes around 24 seconds.

Typically, you want to iterate on the rules and check their statistics. For this, you can use the weak_labels.summary() method:

[5]:
weak_labels.summary()
[5]:
rule polarity coverage overlaps conflicts correct incorrect precision
money {Business} 0.008276 0.002437 0.001936 30 37 0.447761
financ* {Business} 0.019655 0.005893 0.005188 80 55 0.592593
dollar* {Business} 0.016591 0.003542 0.002908 87 37 0.701613
war {World} 0.011779 0.003213 0.001348 75 26 0.742574
gov* {World} 0.045078 0.010878 0.006270 170 174 0.494186
minister* {World} 0.030031 0.007531 0.002821 193 22 0.897674
conflict {World} 0.003041 0.001003 0.000102 18 4 0.818182
footbal* {Sports} 0.013166 0.004945 0.000439 107 7 0.938596
sport* {Sports} 0.021191 0.007045 0.001223 139 23 0.858025
game {Sports} 0.038879 0.014083 0.002375 216 71 0.752613
play* {Sports} 0.052453 0.016889 0.005063 268 112 0.705263
sci* {Sci/Tech} 0.016552 0.002735 0.001309 114 26 0.814286
techno* {Sci/Tech} 0.027218 0.008433 0.003174 155 60 0.720930
computer* {Sci/Tech} 0.027320 0.011058 0.004459 159 54 0.746479
software {Sci/Tech} 0.030243 0.009655 0.003346 184 41 0.817778
web {Sci/Tech} 0.015376 0.004067 0.001607 76 25 0.752475
total {Sci/Tech, Business, Sports, World} 0.317022 0.053582 0.019561 2071 774 0.727944

From the above, we see that our rules cover around 30% of the original training set with an average precision of 0.72. Our hope is that the label and downstream models will improve both the recall and the precision of the final classifier.

3. Denoise weak labels with Snorkel’s Label Model

The goal at this step is to denoise the weak labels we’ve just created using rules. There are several approaches to this problem using different statistical methods.

In this tutorial, we’re going to use Snorkel but you can actually use any other Label model or weak supervision method (see the Weak supervision guide for more details).

For convenience, Rubrix defines a simple wrapper over Snorkel’s Label Model so it’s easier to use with Rubrix weak labels and datasets:

[6]:
# If Snorkel is not installed on your machine, install it with:
# !pip install snorkel

label_model = Snorkel(weak_labels)

# Fit Label Model
label_model.fit()

# Test with labeled test set
label_model.score()
WARNING:rubrix.labeling.text_classification.label_models:Metrics are only calculated over non-abstained predictions!
[6]:
{'accuracy': 0.7448246725813266}

4. Prepare our training set

Now, we already have a “denoised” training set, which we can prepare for training a downstream model.

The label model’s predict method returns TextClassificationRecord objects with the predictions from the label model.

We can refine and review these records in the Rubrix webapp, use them as they are, or filter them by score, for example.

In this case, we assume the predictions are precise enough and use them without any revision.
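If you preferred to keep only the most confident predictions, a minimal sketch could filter the records by the score of their top prediction. It assumes, as the code further below does, that each record’s prediction is a list of (label, score) pairs sorted by descending score; the 0.8 threshold is arbitrary:

# Hypothetical filter: keep only records with a confident top prediction
confident_records = [
    record
    for record in label_model.predict()
    if record.prediction[0][1] > 0.8  # score of the highest-scoring label
]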

Our training set has ~38,000 records, which corresponds to all records where the label model has not abstained.

[20]:
records = label_model.predict()

# build a simple dataframe with text and the prediction with the highest score
df_train = pd.DataFrame([
    {"text": record.inputs["text"], "label": label_model.weak_labels.label2int[record.prediction[0][0]]}
    for record in records
])
df_train
[20]:
text label
0 Jan Baan launches Web services firm com Septem... 0
1 Molson Indy Vancouver gets black flag quot;Th... 1
2 The football gods were on our side #39; Jason ... 1
3 Jags get offense clicking in second half Fred ... 1
4 Puzzle Over Low Galaxy Count Scientists from t... 0
... ... ...
38080 Football legend Maradona rushed to hospital Fo... 1
38081 Head of British charity expelled from Sudan Th... 3
38082 From SANs to SATAs, storage vendors continue p... 0
38083 Billups Sits Out Because of Ankle Sprain (AP) ... 1
38084 Judge Rules for Oracle in PeopleSoft Bid (Reut... 0

38085 rows × 2 columns

[19]:
# for the test set, we can retrieve the records with validated annotations (the original ag_news test set)
df_test = rb.load("news", query="status:Validated")

df_test['text'] = df_test.inputs.transform(lambda r: r['text'])
df_test['annotation'] = df_test['annotation'].apply(
    lambda r:label_model.weak_labels.label2int[r]
)

5. Train a downstream model with scikit-learn

Now, let’s train our final model using scikit-learn:

[ ]:
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline

classifier = Pipeline([
    ('vect', CountVectorizer()),
    ('clf', MultinomialNB())
])

classifier.fit(
    X=df_train.text.tolist(),
    y=df_train.label.values
)
[18]:
accuracy = classifier.score(
    X=df_test.text.tolist(),
    y=label_model.weak_labels.annotation()
)

f"Test accuracy: {accuracy}"
[18]:
'Test accuracy: 0.8177631578947369'

Not too bad!

We have achieved around 0.82 accuracy without even using a single example from the original ag_news train set and with a small set of rules (16 in total). We have also clearly improved over the 0.74 accuracy of our Label Model.

Finally, let’s take a look at more detailed metrics:

[82]:
from sklearn import metrics

labels = list(label_model.weak_labels.label2int.keys())[1:] # removes "abstain" label
predicted = classifier.predict(df_test.text.tolist())

print(metrics.classification_report(label_model.weak_labels.annotation(), predicted, target_names=labels))
              precision    recall  f1-score   support

    Sci/Tech       0.76      0.83      0.80      1900
      Sports       0.86      0.98      0.91      1900
    Business       0.89      0.56      0.69      1900
       World       0.79      0.89      0.84      1900

    accuracy                           0.82      7600
   macro avg       0.82      0.82      0.81      7600
weighted avg       0.82      0.82      0.81      7600

Next steps

If you are interested in the topic of weak supervision check the Weak supervision guide.

🙋‍♀️ Join the Rubrix community on Slack

⭐ Rubrix Github repo to stay updated.