🧱 Extending weak supervision workflows with sentence embeddings#

In this tutorial, we show how weak supervision workflows in Rubrix can be extended with sentence embeddings. We start from the weak supervision workflow presented in our Weak supervision with Rubrix tutorial and improve on its results by extending the coverage of its rules.

  • ✍️ We define rules and generate weak labels for the ag_news data set.

  • 🧱 We extend our weak labels with sentence embeddings from the Sentence Transformers library.

  • 📰 Finally, we use a label model to generate data for training a downstream model as a news classifier.

  • 🚀 We achieve a 4% improvement in accuracy over the original workflow simply by extending our weak labels.

Original and extended coverage of the weak labels

The two plots above show the coverage of the weak labels before and after extending them with embeddings. Each point corresponds to an example in the ag_news test set, and its color indicates the class of the example. Points surrounded by a transparent circle are covered by at least one rule.

Introduction#

Labelling functions normally have high precision, but low coverage. Only records that strictly match the conditions determined by a given function will be labelled, while other potential candidates will be left out.

Building on the findings of the Hazy Research group, we present a way to solve this problem by extending the weak labels produced by our labelling functions with sentence embeddings.

We extend the coverage of our labelling functions by giving unlabelled records the same label as their nearest labelled neighbor in the embedding space if the cosine similarity between them scores above a certain threshold.

We will show in this tutorial that, by adjusting these similarity thresholds and selecting proper sentence embeddings, we are able to significantly improve the accuracy of the downstream classifiers produced by our weak supervision workflows.
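
To make the mechanism concrete, here is a minimal sketch of the idea with toy numpy arrays. It is only an illustration of the nearest-neighbor extension described above, not Rubrix's actual implementation (the function name and the data are made up):

[ ]:
import numpy as np

def extend_rule_column(weak_label_column, embeddings, threshold):
    """Toy sketch: propagate one rule's labels to uncovered records whose
    most similar covered record scores above the threshold (-1 = abstain)."""
    # normalize the embeddings so that dot products are cosine similarities
    emb = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    covered = weak_label_column != -1
    extended = weak_label_column.copy()
    for i in np.where(~covered)[0]:
        similarities = emb[covered] @ emb[i]
        best = similarities.argmax()
        if similarities[best] > threshold:
            # inherit the label of the nearest covered record
            extended[i] = weak_label_column[covered][best]
    return extended

# one rule that only covered the first of four records with label 0
column = np.array([0, -1, -1, -1])
vectors = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.5, 0.5]])
print(extend_rule_column(column, vectors, threshold=0.8))  # -> [ 0  0 -1 -1]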

Detailed Workflow#

A typical workflow to perform weak supervision with sentence embeddings is:

  1. Create a Rubrix dataset with your raw dataset. If you have some labelled data, you can log it into the same dataset.

  2. Define a set of weak labeling rules with the Rules definition mode in the UI.

  3. Create a WeakLabels object and apply the rules. You can load the rules from your dataset and add additional rules and labeling functions using Python. Typically, you’ll iterate between this step and step 2.

  4. Extend the WeakLabels object by providing sentence embeddings for each record (the rows of the matrix) and a similarity threshold for each rule (the columns of the matrix).

  5. Once you are satisfied with your extended weak labels, use the extended matrix of the WeakLabels instance with your library/method of choice to build a training set or even train a downstream text classification model. You can iterate between this step and step 4, trying different thresholds and embeddings until you achieve a satisfactory result.

This guide shows you an end-to-end example using Snorkel. You could alternatively use any other label model available in Rubrix. If you are interested in learning about other options, please check our weak supervision guide.

Setup#

Rubrix is a free and open-source tool to explore, annotate, and monitor data for NLP projects.

If you are new to Rubrix, check out the ⭐ Github repository.

If you have not installed and launched Rubrix yet, check the Setup and Installation guide.

For this tutorial we also need some third-party libraries that can be installed via pip:

[ ]:
%pip install faiss-cpu sentence_transformers transformers datasets

The dataset#

Since this tutorial is an extension of our Weak supervision with Rubrix tutorial, we will also use the ag_news dataset, a well-known benchmark for text classification models.

However, to guarantee a fair comparison, we will optimize the thresholds on a validation split, and leave the test split for the final evaluation.

[ ]:
from datasets import load_dataset

agnews = load_dataset("ag_news")

agnews_train, agnews_valid = agnews["train"].train_test_split(test_size=7600, seed=43).values()

1. Create a Rubrix dataset with unlabelled data and test data#

Just like in the first tutorial, let’s load a labelled and unlabelled set of records into Rubrix.

[ ]:
import rubrix as rb

# build our labelled records to evaluate our heuristic rules and optimize the thresholds
records = [
    rb.TextClassificationRecord(
        text=record["text"],
        metadata={"split": "labelled"},
        annotation=agnews_valid.features["label"].int2str(record["label"]),
        id=f"valid_{idx}",
    )
    for idx, record in enumerate(agnews_valid)
]

# build our unlabelled records
records += [
    rb.TextClassificationRecord(
        text=record["text"],
        metadata={"split": "unlabelled"},
        id=f"train_{idx}",
    )
    for idx, record in enumerate(agnews_train)
]

# log the records to Rubrix
rb.log(records, name="news2")

After this step, you have a fully browsable dataset available that you can access via the Rubrix web app.
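
If you want to verify the upload from Python as well, rb.load lets you fetch the logged records back from the server (just a quick sanity check; the exact return type depends on your Rubrix version):

[ ]:
# fetch the logged records back from the Rubrix server
news_dataset = rb.load("news2")
print(len(news_dataset))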

2. Defining rules#

We will use the same rules as found in the previous tutorial.

[ ]:
from rubrix.labeling.text_classification import Rule

# define queries and patterns for each category (using ES DSL)
queries = [
    (["money", "financ*", "dollar*"], "Business"),
    (["war", "gov*", "minister*", "conflict"], "World"),
    (["footbal*", "sport*", "game", "play*"], "Sports"),
    (["sci*", "techno*", "computer*", "software", "web"], "Sci/Tech"),
]

# define rules
rules = [
    Rule(query=term, label=label)
    for terms, label in queries
    for term in terms
]

3. Building and analyzing weak labels#

After building weak labels from our rules, their summary reveals that our rules have, in total, 31% coverage while achieving 74% precision.

[ ]:
from rubrix.labeling.text_classification import WeakLabels

# apply the rules to the dataset to obtain the weak labels
weak_labels = WeakLabels(
    rules=rules,
    dataset="news2"
)
[5]:
weak_labels.summary()
[5]:
rule label coverage annotated_coverage overlaps conflicts correct incorrect precision
money {Business} 0.008242 0.008816 0.002450 0.001925 31 36 0.462687
financ* {Business} 0.019775 0.021184 0.005892 0.005183 115 46 0.714286
dollar* {Business} 0.016608 0.016974 0.003492 0.002850 98 31 0.759690
war {World} 0.011683 0.008816 0.003242 0.001367 44 23 0.656716
gov* {World} 0.045067 0.043158 0.010800 0.006225 156 172 0.475610
minister* {World} 0.030142 0.030263 0.007508 0.002825 207 23 0.900000
conflict {World} 0.003050 0.003684 0.001025 0.000092 20 8 0.714286
footbal* {Sports} 0.013050 0.015132 0.004875 0.000408 105 10 0.913043
sport* {Sports} 0.021183 0.021711 0.007033 0.001225 146 19 0.884848
game {Sports} 0.038950 0.043026 0.014067 0.002375 253 74 0.773700
play* {Sports} 0.052608 0.057632 0.016767 0.004992 312 126 0.712329
sci* {Sci/Tech} 0.016433 0.015658 0.002742 0.001275 101 18 0.848739
techno* {Sci/Tech} 0.027150 0.028816 0.008325 0.003108 153 66 0.698630
computer* {Sci/Tech} 0.027275 0.026447 0.011100 0.004483 167 34 0.830846
software {Sci/Tech} 0.030283 0.032763 0.009625 0.003308 202 47 0.811245
web {Sci/Tech} 0.015508 0.016316 0.004100 0.001608 111 13 0.895161
total {Sci/Tech, World, Sports, Business} 0.317375 0.327895 0.053408 0.019425 2221 746 0.748568

In the next steps, we will try to extend our weak labels matrix through sentence embeddings. In this way, we will increase the coverage of our rules, while maintaining an acceptable precision.

4. Using the weak labels#

Label model with Snorkel#

Snorkel’s label model is by far the most popular option for using weak supervision, and Rubrix provides built-in support for it. Here we fit our weak labels to the Snorkel label model, and then we check the performance on the records that have been covered by the rules.

[6]:
from rubrix.labeling.text_classification import Snorkel

# create the Snorkel label model
label_model = Snorkel(weak_labels)

# fit the model, for the learning rate and epochs we ran a quick grid search
label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)

# evaluate the label model
print(label_model.score(output_str=True))
              precision    recall  f1-score   support

    Business       0.73      0.41      0.53       493
      Sports       0.77      0.97      0.86       703
       World       0.69      0.83      0.75       462
    Sci/Tech       0.80      0.74      0.77       833

    accuracy                           0.76      2491
   macro avg       0.75      0.74      0.73      2491
weighted avg       0.76      0.76      0.74      2491

5. Extending the weak labels#

Let’s extend our weak labels and see how that impacts the evaluation of the Snorkel label model.

Generate sentence embeddings#

Let’s generate sentence embeddings for each record of our weak labels matrix. The best results are achieved with powerful general-purpose pretrained embeddings, or with embeddings specifically pretrained for the domain of the task at hand.

Here we choose the all-mpnet-base-v2 embeddings from the well-known Sentence Transformers library. Rubrix allows us to experiment with embeddings from any source, as long as they are provided to the weak labels matrix as a two-dimensional array.

For instance, instead of Sentence Transformers, we could have used GPT-3 similarity embeddings from the OpenAI Embeddings API, text embeddings from TensorFlow Hub, or we could even have trained our own embeddings from scratch.
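
As a brief illustration of that flexibility, the following sketch computes such a two-dimensional array with the Universal Sentence Encoder from TensorFlow Hub. It is not part of this tutorial's workflow and assumes that tensorflow and tensorflow_hub are installed:

[ ]:
import tensorflow_hub as hub

# load the Universal Sentence Encoder and embed all record texts at once
use_model = hub.load("https://tfhub.dev/google/universal-sentence-encoder/4")
alternative_embeddings = use_model([rec.text for rec in weak_labels.records()]).numpy()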

[ ]:
from sentence_transformers import SentenceTransformer
from tqdm.auto import tqdm

# instantiate the model for the sentence embeddings
# we strongly recommend using a GPU for the computation of the embeddings
model = SentenceTransformer('all-mpnet-base-v2', device='cuda')

# compute the embeddings and store them in a list
embeddings = []
for rec in tqdm(weak_labels.records()):
    embeddings.append(model.encode(rec.text))
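
Encoding the records one by one works, but SentenceTransformer.encode also accepts a list of texts and batches them internally, which is usually considerably faster on a GPU. A sketch of the equivalent batched call:

[ ]:
# equivalent to the loop above, but letting the model batch the texts internally
texts = [rec.text for rec in weak_labels.records()]
embeddings = model.encode(texts, batch_size=64, show_progress_bar=True)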

Set the thresholds#

We start by making an educated guess on which thresholds will work for this particular weak labels matrix. We set the thresholds for all rules to 0.60. This means that, for each rule, the label of a record will be extended to its nearest unlabelled neighbor if their cosine similarity is above this value.

[ ]:
thresholds = [0.6] * len(rules)

Extend the weak labels matrix#

We call the extend_matrix method by providing the thresholds and the sentence embeddings.

[ ]:
weak_labels.extend_matrix(thresholds, embeddings)

With the weak label matrix extended, we can check that our coverage went up significantly (from 0.32 to 0.79).

[10]:
weak_labels.summary()
[10]:
rule label coverage annotated_coverage overlaps conflicts correct incorrect precision
money {Business} 0.079342 0.083158 0.068042 0.061908 342 290 0.541139
financ* {Business} 0.122692 0.130526 0.096475 0.087525 596 396 0.600806
dollar* {Business} 0.104917 0.110789 0.084467 0.077458 552 290 0.655582
war {World} 0.083958 0.081053 0.069425 0.050867 353 263 0.573052
gov* {World} 0.218817 0.216447 0.157083 0.118283 843 802 0.512462
minister* {World} 0.124367 0.121974 0.091817 0.055133 752 175 0.811219
conflict {World} 0.038975 0.036579 0.035533 0.023192 199 79 0.715827
footbal* {Sports} 0.040083 0.043289 0.029475 0.007708 308 21 0.936170
sport* {Sports} 0.114292 0.111053 0.084200 0.031992 745 99 0.882701
game {Sports} 0.125608 0.130526 0.094600 0.035858 745 247 0.751008
play* {Sports} 0.205158 0.206053 0.156583 0.091692 900 666 0.574713
sci* {Sci/Tech} 0.058375 0.056053 0.035192 0.028675 257 169 0.603286
techno* {Sci/Tech} 0.117850 0.124079 0.092108 0.074517 558 385 0.591729
computer* {Sci/Tech} 0.100017 0.097632 0.079608 0.060317 553 189 0.745283
software {Sci/Tech} 0.088967 0.088816 0.065025 0.046550 517 158 0.765926
web {Sci/Tech} 0.084800 0.081579 0.067950 0.054550 415 205 0.669355
total {Sci/Tech, Sports, Business, World} 0.793542 0.793553 0.392908 0.228892 8635 4434 0.660724

We also see that the average precision of our rules went down (from 0.75 to 0.66). This drop, however, can be partially compensated by our label model. If we fit our weak labels to a Snorkel label model again, we can see that the support went up significantly, as expected, while the drop in accuracy is minor.

[12]:
label_model = Snorkel(weak_labels)
label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)
print(label_model.score(output_str=True))
              precision    recall  f1-score   support

    Sci/Tech       0.76      0.74      0.75      1636
       World       0.67      0.86      0.76      1421
      Sports       0.79      0.96      0.87      1544
    Business       0.78      0.39      0.52      1430

    accuracy                           0.74      6031
   macro avg       0.75      0.74      0.72      6031
weighted avg       0.75      0.74      0.73      6031

Have a look at the Appendix for a detailed explanation of how the weak label matrix is extended under the hood.

Instead of using generic fixed thresholds, we recommend optimizing them in some way to get the highest performance gains. Our optimization, described in detail in the Appendix, yielded the following thresholds:

[ ]:
optimized_thresholds = [0.4, 0.4, 0.6, 0.4, 0.5, 0.8, 1., 0.4, 0.4, 0.5, 0.6, 0.4, 0.4, 0.6, 0.6, 0.8]

Each call to extend_matrix with both thresholds and embeddings builds a faiss index that is cached inside the weak labels object.

If we do not provide embeddings in subsequent calls to extend_matrix, this index is reused and a new extended matrix replaces the current one. Extending the matrix with new thresholds is therefore very cheap.
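
For example, assuming the summary exposes the "total" row shown above, we could quickly compare the coverage obtained with a few global thresholds before settling on the optimized ones, without recomputing any embeddings:

[ ]:
# cheap sweep: the cached index is reused, only the thresholds change
for value in (0.5, 0.6, 0.7, 0.8):
    weak_labels.extend_matrix([value] * len(rules))
    print(value, weak_labels.summary().loc["total", "coverage"])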

[28]:
weak_labels.extend_matrix(optimized_thresholds)
label_model = Snorkel(weak_labels)
label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)
print(label_model.score(output_str=True))
              precision    recall  f1-score   support

    Sci/Tech       0.74      0.67      0.70      1906
       World       0.78      0.64      0.70      1789
      Sports       0.82      0.90      0.86      1875
    Business       0.60      0.69      0.64      1877

    accuracy                           0.73      7447
   macro avg       0.73      0.73      0.73      7447
weighted avg       0.73      0.73      0.73      7447

The optimized thresholds seem to further reduce the accuracy of the label model, but also increase the coverage significantly.

6. Training a downstream model#

Now we will train the same downstream model as in the previous tutorial, but on the data produced by a label model from our extended weak labels.

Let us first define a helper function that is basically a copy and paste from the previous tutorial.

[ ]:
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import Pipeline
from sklearn import metrics


def train_and_evaluate_downstream_model(label_model):
    """
    Train a downstream model with the predictions of a label model and
    evauate it with the test split of the ag news dataset
    """
    # get records with the predictions from the label model
    records = label_model.predict()

    # turn str labels into integers
    label2int = label_model.weak_labels.label2int

    # extract training data
    X_train = [rec.text for rec in records]
    y_train = [label2int[rec.prediction[0][0]] for rec in records]

    # define our final classifier
    classifier = Pipeline([
        ('vect', CountVectorizer()),
        ('clf', MultinomialNB())
    ])

    # fit the classifier
    classifier.fit(
        X=X_train,
        y=y_train,
    )

    # extract text and labels
    X_test = [rec["text"] for rec in agnews["test"]]
    y_test = [label2int[agnews["test"].features["label"].int2str(rec["label"])] for rec in agnews["test"]]

    # get predictions for the test set
    predicted = classifier.predict(X_test)

    return metrics.classification_report(y_test, predicted, target_names=[k for k in label2int.keys() if k])

Now let’s see how our downstream model compares with the original model from the previous tutorial. Remember we achieved an accuracy of around 82%.

[35]:
print(train_and_evaluate_downstream_model(label_model))
              precision    recall  f1-score   support

    Sci/Tech       0.85      0.82      0.83      1900
       World       0.90      0.84      0.87      1900
      Sports       0.90      0.97      0.93      1900
    Business       0.81      0.82      0.81      1900

    accuracy                           0.86      7600
   macro avg       0.86      0.86      0.86      7600
weighted avg       0.86      0.86      0.86      7600

Now, with our extended weak label matrix, we were able to achieve an accuracy of 86%, a 4% improvement over our original approach.

Summary#

In this tutorial you have seen how to improve your weak supervision workflows in Rubrix using sentence embeddings. With very small changes to the original workflow, we were able to significantly increase the accuracy of our downstream models. This shows that Rubrix can greatly reduce the amount of effort that human annotators need to put into writing rules before they can achieve exceptional results.

Next steps#

⭐ Rubrix Github repo to stay updated.

📚 Rubrix documentation for more guides and tutorials.

🙋‍♀️ Join the Rubrix community! A good place to start is the discussion forum.

Appendix: Visualize changes#

Let’s visualize how the weak labels matrix is being extended in a single row.

[ ]:
import pandas as pd

def get_transitions(weak_labels, idx):
    # pair the original and the extended weak label of every rule, for each record
    transitions = [
        list(zip(original_row, extended_row))
        for original_row, extended_row in zip(weak_labels._matrix, weak_labels._extended_matrix)
    ]
    # keep only the record we are interested in
    transitions = transitions[idx]
    label_dict = weak_labels.int2label
    # the rule names are the index of the summary, excluding the final "total" row
    rule_labels = weak_labels.summary().reset_index()["index"].values.tolist()[:-1]
    transitions_df = []
    for rule_idx, rule in enumerate(rule_labels):
        old_label, new_label = transitions[rule_idx]
        transitions_df.append({
            "rule": rule,
            "old label": label_dict[old_label],
            "new label": label_dict[new_label],
        })
    transitions_df = pd.DataFrame(transitions_df)
    text = weak_labels.records()[idx].text
    return transitions_df, text

transitions, text = get_transitions(weak_labels, 15)

By reading the selected record, we can clearly notice that it is a news article about world politics, and therefore should be classified as World.

[79]:
text
[79]:
'Israel  #39;determined to complete Gaza plan #39; Israel is determined to go ahead with its unilateral withdrawal from the Gaza Strip - regardless of the death of Yasser Arafat and even if settlers resist - a top Israeli general who helped design the plan says.'

Let’s put the row of the original weak labels matrix for this record (the "old label" row) side by side with the same row after the extension (the "new label" row).

We see that this news article was not labelled in the original matrix by any of our rules.

However, it was the nearest unlabelled neighbor of two Business articles, matched by the rules financ* and dollar*, and its similarity with them scored above our selected thresholds. The same happened for two World articles, matched by the rules war and minister*, and for a Sci/Tech article matched by the rule sci*.

[80]:
transitions.transpose()
[80]:
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
rule money financ* dollar* war gov* minister* conflict footbal* sport* game play* sci* techno* computer* software web
old label None None None None None None None None None None None None None None None None
new label None Business Business World None World None None None None None Sci/Tech None None None None

Appendix: Optimizing the thresholds#

Each call to extend_matrix with both thresholds and embeddings builds a faiss index that is cached inside the weak labels object.

If we do not provide embeddings in subsequent calls to extend_matrix, this index is reused and a new extended matrix replaces the current one. The new matrix is an extension of the original weak labels matrix made according to the new similarity thresholds.

[ ]:
# Let's try to set all thresholds to 0.8 instead of 0.6.
thresholds = [0.8] * len(rules)

# As we have already generated the index in our first call, we just need to provide the thresholds.
weak_labels.extend_matrix(thresholds)

There are a few different approaches to finding the best similarity thresholds for extending a weak labels matrix; below we list them from the least to the most computationally expensive.

1. Block the extension of low overlap rules#

After setting all similarity thresholds to a reasonable value, a good way to optimize the similarity thresholds on an individual level is to block the extension of rules with low overlap, as they are more likely to produce inaccurate results after extension.

[111]:
summary = weak_labels.summary(normalize_by_coverage=True).reset_index().head(len(rules))
summary = summary.rename(columns={"index":"rule"})
summary = summary.sort_values(by="overlaps", ascending=True)[["rule", "overlaps"]]
summary = summary.reset_index()
summary
[111]:
index rule overlaps
0 4 gov* 0.239645
1 15 web 0.264374
2 14 software 0.317832
3 10 play* 0.318707
4 6 conflict 0.336066
5 9 game 0.361147
6 8 sport* 0.393875
7 11 sci* 0.483512
8 5 minister* 0.514205
9 7 footbal* 0.567152
10 13 computer* 0.703716
11 12 techno* 0.710083
12 1 financ* 0.737078
13 2 dollar* 0.742097
14 3 war 0.744417
15 0 money 0.788047
[112]:
thresholds = [0.6] * len(rules)

# Let's block the extension of the 6 rules with the least overlap.
turn_off_index = summary['index'][0:6]

# We block the extension of a rule by setting its similarity threshold to 1.0.
for rule_index in turn_off_index:
    thresholds[rule_index] = 1.0

weak_labels.extend_matrix(thresholds)
label_model = Snorkel(weak_labels)
label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)
print(train_and_evaluate_downstream_model(label_model))
              precision    recall  f1-score   support

    Sci/Tech       0.81      0.84      0.82      1900
       World       0.90      0.87      0.88      1900
      Sports       0.91      0.96      0.93      1900
    Business       0.81      0.76      0.78      1900

    accuracy                           0.86      7600
   macro avg       0.85      0.86      0.85      7600
weighted avg       0.85      0.86      0.85      7600

2. Brute force: Grid search over the label model#

In this approach, we set all thresholds to an initial value and then grid search for the best value of each one individually, optimizing for the harmonic mean between the coverage and the accuracy of the label model on the development set. This ensures that we choose the thresholds with the best trade-off between both metrics.

We arrive at the same improvement as the previous approach, with a final accuracy of 86% over the test set.

[ ]:
def train_eval_labelmodel(ths):
    weak_labels.extend_matrix(ths)

    label_model = Snorkel(weak_labels)
    label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)

    metrics = label_model.score()
    acc, sup, n = metrics["accuracy"], metrics["macro avg"]["support"], len(weak_labels.annotation())
    coverage = sup / n
    return 2 * acc * coverage / ( acc + coverage )
[ ]:
import numpy as np
from tqdm.auto import tqdm

ths_range = np.arange(1, 0.3, -0.1)
n_ths = len(weak_labels.rules)

best_thresholds = [1.0] * n_ths
best_acc = 0.0
for i in tqdm(range(n_ths), total=n_ths):
    thresholds = best_thresholds.copy()
    for threshold in ths_range:
        thresholds[i] = threshold
        acc = train_eval_labelmodel(thresholds)
        if acc > best_acc:
            best_acc = acc
            best_thresholds = thresholds.copy()
[121]:
np.array(best_thresholds)
[121]:
array([0.4, 0.4, 0.6, 0.4, 0.5, 0.8, 1. , 0.4, 0.4, 0.5, 0.6, 0.4, 0.4,
       0.6, 0.6, 0.8])
[122]:
weak_labels.extend_matrix(best_thresholds)
label_model = Snorkel(weak_labels)
label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)
print(train_and_evaluate_downstream_model(label_model))
              precision    recall  f1-score   support

    Sci/Tech       0.83      0.83      0.83      1900
       World       0.89      0.85      0.87      1900
      Sports       0.90      0.97      0.93      1900
    Business       0.82      0.79      0.80      1900

    accuracy                           0.86      7600
   macro avg       0.86      0.86      0.86      7600
weighted avg       0.86      0.86      0.86      7600

3. Brute force: Grid search over the downstream model#

Here again we set all thresholds to an initial value and grid search for the best value for each individual threshold, but now we optimize for the accuracy of the downstream model on the development set. We arrive at a final accuracy of 85% on the test set, which is slightly less than what we achieved through the previous approaches.

[126]:
# retrieve records with annotations
test_ds = weak_labels.records(has_annotation=True)

# extract text and labels
X_test_for_grid_search = [rec.text for rec in test_ds]
y_test_for_grid_search = [weak_labels.label2int[rec.annotation] for rec in test_ds]

def train_eval_downstream(ths):
    weak_labels.extend_matrix(ths)

    label_model = Snorkel(weak_labels)
    label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)

    records = label_model.predict()

    X_train = [rec.text for rec in records]
    y_train = [weak_labels.label2int[rec.prediction[0][0]] for rec in records]

    classifier = Pipeline([
        ('vect', CountVectorizer()),
        ('clf', MultinomialNB())
    ])

    classifier.fit(
        X=X_train,
        y=y_train,
    )

    accuracy = classifier.score(
        X=X_test_for_grid_search,
        y=y_test_for_grid_search,
    )

    return accuracy
[ ]:
from tqdm.auto import tqdm

best_thresholds, best_acc = [1.0] * len(weak_labels.rules), 0
ths_range = np.arange(1, 0.3, -0.1)
n_ths = len(weak_labels.rules)

for i in tqdm(range(n_ths), total=n_ths):
    thresholds = best_thresholds.copy()
    for threshold in ths_range:
        thresholds[i] = threshold
        acc = train_eval_downstream(thresholds)
        if acc > best_acc:
            best_acc = acc
            best_thresholds = thresholds.copy()
[128]:
np.array(best_thresholds)
[128]:
array([0.6, 0.7, 0.9, 1. , 1. , 0.8, 0.7, 1. , 0.6, 1. , 1. , 0.7, 0.8,
       0.7, 0.9, 0.8])
[129]:
weak_labels.extend_matrix(best_thresholds)
label_model = Snorkel(weak_labels)
label_model.fit(lr=0.002, n_epochs=10, progress_bar=False)
print(train_and_evaluate_downstream_model(label_model))
              precision    recall  f1-score   support

    Sci/Tech       0.81      0.82      0.82      1900
       World       0.89      0.85      0.87      1900
      Sports       0.88      0.98      0.93      1900
    Business       0.82      0.75      0.78      1900

    accuracy                           0.85      7600
   macro avg       0.85      0.85      0.85      7600
weighted avg       0.85      0.85      0.85      7600

Tips on threshold optimization#

Grid search with large downstream models, such as transformers, can be very expensive. In this scenario, we can consider optimizing only a subset of the thresholds, or optimizing all thresholds on a small sample of the development set.
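
As a rough sketch of the first option, the snippet below reuses ths_range and train_eval_downstream from the previous section and restricts the coordinate-wise search to the five rules with the highest coverage, leaving all other thresholds at 1.0 (no extension). The choice of subset is only illustrative, and it assumes that the rule order in summary() matches the column order of the weak labels matrix:

[ ]:
# coverage of each rule as reported by the current summary (drop the "total" row)
rule_coverage = weak_labels.summary()["coverage"][:-1]

# positions of the 5 rules with the highest coverage
subset_idx = rule_coverage.reset_index(drop=True).sort_values(ascending=False).index[:5]

best_thresholds, best_acc = [1.0] * len(weak_labels.rules), 0.0
for i in tqdm(subset_idx):
    thresholds = best_thresholds.copy()
    for threshold in ths_range:
        thresholds[i] = threshold
        acc = train_eval_downstream(thresholds)
        if acc > best_acc:
            best_acc, best_thresholds = acc, thresholds.copy()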

Although in this tutorial we perform the grid search sequentially, nothing prevents us from speeding it up by running it in parallel, as long as we make a deep copy of the weak labels object for each process or thread.

Appendix: Plot extension#

[ ]:
import umap
import matplotlib.pyplot as plt

# project the sentence embeddings to two dimensions with UMAP
umap_data = umap.UMAP(n_neighbors=15, n_components=2, min_dist=0.0, metric='cosine').fit_transform(embeddings)

# gather the records, their 2D coordinates and their (extended) weak labels in a DataFrame
df = rb.DatasetForTextClassification(weak_labels.records()).to_pandas()
df["x"], df["y"] = umap_data[:, 0], umap_data[:, 1]
df["wl"] = [em for em in weak_labels._matrix]
df["wl_ext"] = [em for em in weak_labels._extended_matrix]

# a record is covered if at least one of the 16 rules does not abstain (-1)
cov_idx = df["wl"].map(lambda x: x.sum() != -16)
cov_ext_idx = df["wl_ext"].map(lambda x: x.sum() != -16)
test_idx = ~(df.annotation.isna())

df_test = df[test_idx]
df_cov, df_cov_ext = df[cov_idx & test_idx], df[cov_ext_idx & test_idx]

label2int = {label: i for i, label in enumerate(df_test.annotation.value_counts().index)}

fig, ax = plt.subplots(1, 2, figsize=(13, 6))

# left plot: small points are all annotated records, large transparent circles are records covered by the original rules
ax[0].scatter(df_test.x, df_test.y, c=df_test.annotation.map(lambda x: label2int[x]), s=10)
ax[0].scatter(df_cov.x, df_cov.y, c=df_cov.annotation.map(lambda x: label2int[x]), s=100, alpha=0.2)

# right plot: same, but with the coverage after extending the weak labels
scatter = ax[1].scatter(df_test.x, df_test.y, c=df_test.annotation.map(lambda x: label2int[x]), s=10)
ax[1].scatter(df_cov_ext.x, df_cov_ext.y, c=df_cov_ext.annotation.map(lambda x: label2int[x]), s=100, alpha=0.2)

ax[0].set_title("Original", {"fontsize": "xx-large"})
ax[0].set_xticks([]), ax[0].set_yticks([])

ax[1].set_title("Extended", {"fontsize": "xx-large"})
ax[1].set_xticks([]), ax[1].set_yticks([])

labels = list(scatter.legend_elements())
labels[1] = list(label2int.keys())
legend1 = ax[0].legend(*labels, loc="lower right", fontsize="xx-large")
ax[0].add_artist(legend1)

fig.tight_layout()
plt.savefig("extend_weak_labels.png", facecolor='white', transparent=False)