🐭 Weakly supervised NER with skweak#

This tutorial will walk you through the process of using Rubrix to improve weak supervision and data programming workflows with the skweak library.

  • Using Rubrix, skweak and spaCy, we define heuristic rules for the CoNLL 2003 dataset.

  • We then log the labelled documents to Rubrix and visualize the results via its web app.

  • After aggregating the noisy labels, we fine-tune and evaluate a spaCy NER model.

Introduction#

Our goal is to show you how you can incorporate Rubrix into data programming workflows to programmatically build training data with a human-in-the-loop approach. We will use the skweak library.

What are weak supervision and skweak?#

Weak supervision is a branch of machine learning in which noisy, lower-quality labels are obtained more efficiently than with manual annotation, for example by writing heuristic rules, and are then aggregated into training data. We can put this into practice with skweak, a library for programmatically building and managing training datasets without manual labeling.

This tutorial#

In this tutorial, we will show you how to extend weak supervision workflows in skweak with Rubrix.

We will take records from the CoNLL 2003 dataset and build our own annotations with skweak. Then we are going to evaluate NER models trained on our annotations on the development set of CoNLL 2003.

Setup#

Rubrix is a free and open-source tool to explore, annotate, and monitor data for NLP projects.

If you are new to Rubrix, check out the ⭐ GitHub repository.

If you have not installed and launched Rubrix yet, check the Setup and Installation guide.

For this tutorial we also need some third party libraries that can be installed via pip:

[ ]:
%pip install -U spacy -qqq
%pip install --user git+https://github.com/NorskRegnesentral/skweak -qqq
!python -m spacy download en_core_web_lg

1. Log the dataset into Rubrix#

Rubrix allows you to log and track data for different NLP tasks (such as Token Classification or Text Classification).

In this tutorial, we will use the English portion of the CoNLL 2003 dataset, a standard Named Entity Recognition benchmark.

The dataset#

We will use skweak’s data programming methods to annotate our training set, with the help of Rubrix for analyzing and reviewing the data. We will then train a model on this training set.

Although the gold labels for the training set of CoNLL 2003 are already known, we will purposefully ignore them, as our goal in this tutorial is to build our own annotations and see how well they perform on the development set.

To simplify the tutorial, only the ORG label will be taken into account, both in training and evaluation. The other labels present in the dataset (LOC, PER and MISC) will be ignored.

We will load the CoNLL 2003 dataset with the help of the datasets library.

[ ]:
from datasets import load_dataset

conll2003 = load_dataset("conll2003")

Logging#

Before we log the development data, we define a utility function that will convert our NER tags from the datasets format to Rubrix annotations.

[ ]:
from spacy.tokens import Doc
from spacy.vocab import Vocab
from spacy.training.iob_utils import iob_to_biluo, biluo_tags_to_offsets

def tags_to_entities(row):
    doc = Doc(Vocab(), words=row["tokens"])
    ner_tags = conll2003["train"].features["ner_tags"].feature.int2str(row["ner_tags"])
    offsets = biluo_tags_to_offsets(doc, iob_to_biluo(ner_tags))

    return [(entity, start, stop) for start, stop, entity in offsets]
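
As a quick sanity check, we can run the function on the first training example ("EU rejects German call to boycott British lamb ."). It should print each entity as a (label, start, end) tuple of character offsets, along the lines of:

[ ]:
print(tags_to_entities(conll2003["train"][0]))
# expected output (character offsets): [('ORG', 0, 2), ('MISC', 11, 17), ('MISC', 34, 41)]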

We define a generator that will yield each row of our dataset as a TokenClassificationRecord object.

[ ]:
import rubrix as rb
from tqdm.auto import tqdm

def dataset_to_records(dataset):
    for row in tqdm(dataset):

        text = " ".join(row["tokens"])

        # seems like we have "empty" rows
        if not text.strip():
            continue

        yield rb.TokenClassificationRecord(
                text=text,
                tokens=row["tokens"],
                annotation=tags_to_entities(row)
            )

Now we upload our records through the Rubrix API for a first inspection. Although we are uploading all annotations, we can filter for ORG entities on the web app.

[ ]:
rb.log(dataset_to_records(conll2003["validation"]), "conll2003_dev")

2. Use Rubrix to write skweak heuristic rules#

Heuristic rules in skweak are applied through labelling functions. Each of these functions takes a spaCy Doc and must yield the start and end token indices of each annotated span, followed by the span's label.
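
For instance, skweak's README illustrates this contract with a small function that labels amounts of money, reproduced here in lightly adapted form:

[ ]:
def money_detector(doc):
    """Labels numbers preceded by a currency symbol as MONEY."""
    for tok in doc[1:]:
        if tok.text[0].isdigit() and tok.nbor(-1).is_currency:
            yield tok.i - 1, tok.i + 1, "MONEY"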

Annotating a specific case: sports teams#

We define our first heuristic rules to match records related to sports teams.

After inspecting the dataset on Rubrix, we notice that several records start with the name of a sports team followed by its game scores.

We also notice that another group of records features the names of two sports teams and their scores after a match against each other.

We write two rules to capture these sports team names as ORG entities.

Visualization of sports team entities in Rubrix
[ ]:
def sports_results_detector(doc):
    """
    Captures a sports team name followed by its game scores.
    Labels the sports team as an ORG.
    Examples:
        Loznica 4 2 0 2 7 4 6
        Berwick 3 0 0 3 1 14 0
    """
    # Label the first word as ORG if all the remaining tokens are numbers.
    if len(doc) < 2:
        return
    for idx, token in enumerate(doc):
        if not idx and token.text.isalpha() and token.text.istitle():
            continue
        elif idx and token.text.isdigit():
            continue
        else:
            break
    else:
        yield 0, 1, "ORG"

def sports_match_detector(doc):
    """
    Captures a sports match.
    Labels both sports teams as ORG.
    Examples:
        Bournemouth 1 Peterborough 2
        Dumbarton 1 Brechin 1
    """
    if len(doc) != 4:
        return

    if (
        doc[0].text.istitle()
        and doc[1].text.isdigit()
        and doc[2].text.istitle()
        and doc[3].text.isdigit()
    ):
        yield 0, 1, "ORG"
        yield 2, 3, "ORG"

Let’s encapsulate our heuristic rules as labelling functions.

Labelling functions are defined as FunctionAnnotator objects, and multiple functions can be grouped inside a single CombinedAnnotator.

[ ]:
from skweak.heuristics import FunctionAnnotator

sports_results_annotator = FunctionAnnotator("sports_results", sports_results_detector)
sports_match_annotator = FunctionAnnotator("sports_match", sports_match_detector)

Although it is possible to call each one of these annotators independently, if we are going to call several annotators at the same time, it is more convenient to group them under a single combined annotator.

We add each one of them to our combined annotator through its add_annotator method.

[ ]:
from skweak.base import CombinedAnnotator

rule_based_annotator = CombinedAnnotator()

for annotator in [sports_results_annotator, sports_match_annotator]:
    rule_based_annotator.add_annotator(annotator)
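
Before writing more rules, we can try the combined annotator on a toy document. skweak annotators store their matches in the doc.spans dictionary, keyed by the annotator's name:

[ ]:
# Whitespace-tokenized toy document matching the sports_match_detector examples.
doc = Doc(Vocab(), words=["Dumbarton", "1", "Brechin", "1"])
doc = list(rule_based_annotator.pipe([doc]))[0]
print(doc.spans["sports_match"])  # spans covering "Dumbarton" and "Brechin"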

Annotating with generic rules#

We can also write rules that are a little more generic.

For instance, organization names often appear as a sequence of capitalized words that either starts or ends with a certain keyword. We write a generator called title_detector to capture them.

Visualization of generic rules in Rubrix
[ ]:
def title_detector(doc, keyword=None, label="ORG", reverse=False):
    """
    Captures a sequence of capitalized words that either start or end with a certain keyword.
    Labels the sequence, including the keyword, with the ORG label.
    Examples:

        The following examples start with the keyword "U.S.":
        - U.S. Treasury Department
        - U.S. Treasuries
        - U.S. Agriculture Department

        The following examples end with the keyword "Corp":
        - First of Michigan Corp
        - Caltex Petroleum Corp
        - Kia Motor Corp
    """
    start = None
    end = None

    if reverse:
        len_doc = len(doc)
        doc = reversed(doc)

    for idx, token in enumerate(doc):
        if token.text == keyword:
            start = idx
        elif start is not None:  # "start" may be 0 (keyword as first token), so test against None
            if token.text.istitle():
                continue
            else:
                if start + 1 != idx:  # skip spans that consist of the keyword alone
                    end = idx

                    if reverse:
                        start, end = len_doc - end, len_doc - start

                    yield start, end, label

                start = None
                end = None
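
We can check the behaviour of this generator on a small, whitespace-tokenized example:

[ ]:
doc = Doc(Vocab(), words=["The", "U.S.", "Treasury", "Department", "said", "so"])
print(list(title_detector(doc, keyword="U.S.")))  # [(1, 4, 'ORG')]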

We take a small list of keywords that appear at the start of capitalized ORG entities, and initialize an annotator for each one of these keywords. All annotators are added to our combined annotator, rule_based_annotator.

[ ]:
from functools import partial

title_start = [ "Federal", "National", "New", "United", "First", "U.N." ]

for keyword in title_start:
    func = partial(title_detector, keyword=keyword, reverse=False)
    annotator = FunctionAnnotator(keyword + " (start)", func)
    rule_based_annotator.add_annotator(annotator)

We repeat the same process, but this time for keywords that appear at the end of capitalized ORG entities.

[ ]:
title_ending = [
    "Office", "Department", "Association",
    "Corporation", "Army", "Party",
    "Exchange", "Council", "University",
    "Newsroom", "Bureau", "Organisation",
    "Group", "Inc", "Corp", "Ltd"
]

for keyword in title_ending:
    func = partial(title_detector, keyword=keyword, reverse=True)
    annotator = FunctionAnnotator(keyword + " (end)", func)
    rule_based_annotator.add_annotator(annotator)

If you have large lists of keywords that must be labelled as entities on every occurrence (e.g. a list of the names of all Fortune 500 companies), you may be interested in utilizing a GazetteerAnnotator. The Step by step NER tutorial on skweak’s documentation shows how you can utilize gazetteers to annotate your data.
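
As a brief sketch of that approach (not part of this tutorial's pipeline, and with a made-up company list), a gazetteer annotator follows the same pattern as our other labelling functions:

[ ]:
from skweak import gazetteers

# Illustrative placeholder list; a real gazetteer would be much larger.
companies = [("Caltex", "Petroleum", "Corp"), ("Kia", "Motor", "Corp")]
trie = gazetteers.Trie(companies)
gazetteer_annotator = gazetteers.GazetteerAnnotator("companies", {"ORG": trie})
# rule_based_annotator.add_annotator(gazetteer_annotator)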

Annotating with regex#

Until now, all of our rules have manipulated spaCy Doc objects to capture the start and end index of a matching span.

However, it is also possible to capture entities by applying regex patterns directly over the text.

Rubrix has some support for regex operators. If we search for *shire and filter for records annotated as ORG, we will notice that many sports team names end with -shire.

Visualization of regex rules in Rubrix

We can write a rule to capture these entities. This rule can be added to our combined annotator in the same way as all the heuristic rules we have defined so far.

[ ]:
import re

def shire_detector(doc):
    """
    Captures sports team names ending with -shire.
    Examples:
        - Derbyshire
        - Hampshire
        - Worcestershire
    """
    for match in re.finditer("[A-Z][a-z]*shire", doc.text):
        char_start, char_end = match.span()
        span = doc.char_span(char_start, char_end)
        if span:
            yield span.start, span.end, "ORG"
[ ]:
shire_annotator = FunctionAnnotator("shire_team", shire_detector)
rule_based_annotator.add_annotator(shire_annotator)
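
A quick check of the rule on a toy document:

[ ]:
doc = Doc(Vocab(), words=["Worcestershire", "beat", "Kent"])
print(list(shire_detector(doc)))  # [(0, 1, 'ORG')]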

As long as we yield the start index, end index and label of a span, we can capture entities in a Doc object in any way we like.

Beyond regex, another way to detect such entities would be to utilize a Matcher object, as defined on spaCy’s documentation.
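
As a sketch of that alternative (not used in the rest of this tutorial), the -shire rule could be expressed with spaCy's documented REGEX token pattern:

[ ]:
from spacy.matcher import Matcher

def shire_matcher_detector(doc):
    """Same -shire rule as above, expressed with a spaCy Matcher."""
    # Built per call to keep the sketch self-contained; sharing one Matcher
    # across documents with a common vocab would be more efficient.
    matcher = Matcher(doc.vocab)
    matcher.add("SHIRE", [[{"TEXT": {"REGEX": r"^[A-Z][a-z]*shire$"}}]])
    for _, start, end in matcher(doc):
        yield start, end, "ORG"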

Logging to Rubrix#

After defining our labelling functions, it’s time to actually annotate our documents.

First we annotate the development set with gold labels, and add the weak labels of our labelling functions.

[ ]:
def annotate_dataset(dataset, tokens_field="tokens", label_field="ner_tags", gold_field="gold"):
    for row in tqdm(dataset):
        doc = Doc(Vocab(), words=row[tokens_field])
        ner_tags = dataset.features[label_field].feature.int2str(row[label_field])
        offsets = biluo_tags_to_offsets(doc, iob_to_biluo(ner_tags))
        spans = [ doc.char_span(x[0], x[1], label=x[2]) for x in offsets ]
        doc.spans[gold_field] = spans
        yield doc

dev_docs = list(annotate_dataset(conll2003["validation"]))
dev_docs = list(rule_based_annotator.pipe(dev_docs))

Then we log to Rubrix the records for which at least one labelling function triggered a weak label, or for which we have a gold annotation. This lets us quickly spot bugs and edge cases that are not yet covered by our labelling functions.

We also add a doc_index metadata field that will allow us to group the predictions of distinct labelling functions for the same document.

[ ]:
def spans_logger(docs):
    def unroll_spans(span_list):
        return [ (span.label_, span.start_char, span.end_char) for span in span_list ]

    for idx, doc in enumerate(tqdm(docs)):
        tokens = [token.text for token in doc]

        if not tokens:
            continue

        predictions, annotations = {}, None
        for labelling_function, span_list in doc.spans.items():
            if labelling_function == "gold":
                annotations = unroll_spans(span_list)
            else:
                predictions[labelling_function] = unroll_spans(span_list)

        # add records for each labelling function, if they made a prediction
        for agent, prediction in predictions.items():
            if prediction:
                yield rb.TokenClassificationRecord(
                    text=" ".join(tokens),
                    tokens=tokens,
                    prediction=prediction,
                    prediction_agent=agent,
                    annotation=annotations,
                    metadata={"doc_index": idx}
                )

        # add records with annotations, for which no labelling function triggered
        if not any(predictions.values()) and annotations:
            yield rb.TokenClassificationRecord(
                text=" ".join(tokens),
                tokens=tokens,
                annotation=annotations,
                metadata={"doc_index": idx}
            )


rb.log(records=spans_logger(dev_docs), name="conll_2003_dev_spans")
Visualization of the Metadata doc index field in Rubrix

3. Evaluate the precision of our rules#

After getting a bird’s-eye view of our annotations with Rubrix, we can use skweak’s LFAnalysis to numerically evaluate the precision of our rules.

We want to eliminate rules from our combined annotator that have very low precision scores, as they may negatively affect the performance of a model trained on our annotated data.

[38]:
# We evaluate the precision of our heuristic rules

from skweak.analysis import LFAnalysis
import pandas as pd

lf_analysis = LFAnalysis(
    dev_docs,
    ["ORG"]
)

scores = lf_analysis.lf_empirical_scores(
    dev_docs,
    gold_span_name="gold",
    gold_labels=["ORG", "MISC", "PER", "LOC", "O"]
)

def scores_to_df(scores):
    for annotator, label_dict in scores.items():
        for label, metrics_dict in label_dict.items():
            row = {
                "annotator": annotator,
                "label": label,
                "precision": metrics_dict["precision"],
                "recall": metrics_dict["recall"],
                "f1": metrics_dict["f1"]
            }
            yield row

evaluation_df = pd.DataFrame(list(scores_to_df(scores)))\
                    .round(3)\
                    .sort_values(["label", "precision"], ascending=False)\
                    .reset_index(drop=True)
evaluation_df[["annotator", "label", "precision"]]
[38]:
             annotator  label  precision
 0          Corp (end)    ORG      1.000
 1  Organisation (end)    ORG      1.000
 2         Group (end)    ORG      1.000
 3       Council (end)    ORG      1.000
 4    Department (end)    ORG      1.000
 5      Exchange (end)    ORG      1.000
 6        Bureau (end)    ORG      1.000
 7   Corporation (end)    ORG      1.000
 8           Ltd (end)    ORG      1.000
 9      sports_results    ORG      1.000
10                gold    ORG      1.000
11        sports_match    ORG      1.000
12         Party (end)    ORG      1.000
13      Newsroom (end)    ORG      1.000
14          Army (end)    ORG      1.000
15           Inc (end)    ORG      1.000
16          shire_team    ORG      0.982
17         New (start)    ORG      0.909
18        U.N. (start)    ORG      0.882
19   Association (end)    ORG      0.800
20       First (start)    ORG      0.800
21      United (start)    ORG      0.800
22     Federal (start)    ORG      0.714
23    National (start)    ORG      0.640
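
With the scores in a DataFrame, listing the rules that fall below a precision threshold is a one-liner (the 0.75 cut-off is an arbitrary choice for illustration):

[ ]:
# Rules below the threshold are candidates for removal before annotating the training set.
weak_rules = evaluation_df.query("label == 'ORG' and precision < 0.75")["annotator"]
print(weak_rules.tolist())  # ['Federal (start)', 'National (start)'] in this run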

4. Annotate the training data and aggregate the weak labels#

Aggregation#

After carefully considering which rules are appropriate for our dataset, we will annotate the training data and then aggregate our annotations into a single layer.

skweak includes an aggregation model called majority voter. It considers each labelling function as a voter and outputs the most frequent label. We will utilize this majority voter to produce a single set of annotations for our documents, and then we will log the results to Rubrix.

The majority voter is particularly useful when annotating for multiple labels, as in that case the annotations produced by the heuristic rules may not only overlap but also conflict with each other. However, as we are annotating only the ORG label, we won’t need the majority voter to resolve any conflicts: it will simply merge the labels from each annotator into the maj_voter field.

[ ]:
# Create the training docs and annotate them with heuristic rules

train_docs = [ Doc(Vocab(), words=row["tokens"]) for row in conll2003["train"] ]
train_docs = list(rule_based_annotator.pipe(train_docs))
[ ]:
# Perform majority voting over the training data

from skweak.aggregation import MajorityVoter
voter = MajorityVoter("maj_voter", labels=["ORG"], sequence_labelling=True)
train_docs = list(voter.pipe(train_docs))
[ ]:
# Log to Rubrix

rb.log(records=spans_logger(train_docs), name="conll_2003_train")

Although here we are using the majority voter in a rather simple way to vote for a single ORG label, it is possible to assign weights to the vote of each labelling function and even to define complex hierarchies between labels. These details are explained in the majority voter documentation and code on the skweak repository.

Visualization of the majority voter in Rubrix
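
For example, down-weighting our least precise rule might look like the following sketch. Note that the initial_weights argument is an assumption based on skweak's aggregation module; check the API of the skweak version you have installed before relying on it:

[ ]:
# Assumption: MajorityVoter accepts an initial_weights dict that maps labelling
# function names to vote weights (verify against your skweak version).
weighted_voter = MajorityVoter(
    "weighted_voter",
    labels=["ORG"],
    sequence_labelling=True,
    initial_weights={"National (start)": 0.5},  # halve this rule's vote
)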

Generating the training data#

Our final annotations should be set to the field ents of our spaCy Doc objects.

We set the labels defined by our majority voter for the training set, and the gold labels for the development set.

[ ]:
for doc in train_docs:
    doc.set_ents(doc.spans.get("maj_voter", []))

for doc in dev_docs:
    org_ents = filter(lambda token: token.label_ == "ORG", doc.spans.get("gold", []))
    doc.set_ents(org_ents)

In order to avoid training on an unbalanced dataset, we make sure that we have the same number of annotated and blank records in our training data.

[ ]:
import random
random.seed(42)
annotated_docs = [ doc for doc in train_docs if doc.ents ]
empty_docs = random.sample([doc for doc in train_docs if not doc.ents], len(annotated_docs))
train_docs_sample = annotated_docs + empty_docs

Finally, we use skweak’s docbin_writer to write our training and development sets to a binary file format that is compatible with spaCy’s command line tools.

[ ]:
# Save the training and development data.

from skweak.utils import docbin_writer

docbin_writer(train_docs_sample, "/tmp/train.spacy")
docbin_writer(dev_docs, "/tmp/dev.spacy")

5. Evaluate the baseline#

Before we train and evaluate our own solution, let’s test a simple model to see what can be achieved without weak supervision. This gives us a baseline that our solution should improve on.

We evaluate the en_core_web_lg spaCy model on CoNLL 2003. The model has been trained on a distinct dataset, OntoNotes 5.0. We do not perform any sort of adaptation on the model and evaluate its zero-shot performance on the development set.

As can be seen below, this baseline achieves an F-score of 37.3%.

[49]:
from spacy.training import Example
from spacy.scorer import Scorer
import spacy

nlp = spacy.load("en_core_web_lg")
dev_eval_docs = [ nlp(" ".join(row["tokens"])) for row in conll2003["validation"] ]

for doc in dev_eval_docs:
    doc.set_ents(list(filter(lambda x: x.label_ == "ORG", doc.ents)))

scorer_object = Scorer()
scores = scorer_object.score([ Example(dev_eval_docs[i], dev_docs[i]) for i in range(0, len(dev_docs)) ])

pd.DataFrame([{k:v for k,v in scores.items() if k in ["ents_p", "ents_r", "ents_f"]}]).round(3)
[49]:
   ents_p  ents_r  ents_f
0   0.453   0.317   0.373

6. Train and evaluate our model#

Here we train and evaluate a spaCy model on the training data annotated by our heuristic rules.

We initialize our NER model with vectors from our baseline model, en_core_web_lg, and train it for 400 steps.

Our model achieves an F-score of 43.67%, a 6.37 percentage point improvement over our baseline.

[20]:
# After training our model on data annotated with heuristic rules, we reach an F-score of ~44%, about 7 points above the baseline.

!spacy init config - --lang en --pipeline ner --optimize accuracy | \
spacy train - \
--training.max_steps 400 \
--system.seed 42 \
--paths.train /tmp/train.spacy \
--paths.dev /tmp/dev.spacy \
--initialize.vectors en_core_web_lg \
--output /tmp/model
ℹ Saving to output directory: /tmp/model
ℹ Using CPU

=========================== Initializing pipeline ===========================
[2022-01-27 12:38:58,091] [INFO] Set up nlp object from config
[2022-01-27 12:38:58,102] [INFO] Pipeline: ['tok2vec', 'ner']
[2022-01-27 12:38:58,107] [INFO] Created vocabulary
[2022-01-27 12:38:59,423] [INFO] Added vectors: en_core_web_lg
[2022-01-27 12:39:00,729] [INFO] Finished initializing nlp object
[2022-01-27 12:39:23,582] [INFO] Initialized pipeline components: ['tok2vec', 'ner']
✔ Initialized pipeline

============================= Training pipeline =============================
ℹ Pipeline: ['tok2vec', 'ner']
ℹ Initial learn rate: 0.001
E    #       LOSS TOK2VEC  LOSS NER  ENTS_F  ENTS_P  ENTS_R  SCORE
---  ------  ------------  --------  ------  ------  ------  ------
  0       0          0.00     39.17    2.97    2.26    4.33    0.03
  0     200         69.40   1239.88   36.93   91.72   23.12    0.37
  1     400         15.18    280.61   43.67   78.79   30.20    0.44
✔ Saved pipeline to output directory
/tmp/model/model-last
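
Once training has finished, we can load the saved pipeline for a quick smoke test (a sketch; the exact entities found will depend on your training run):

[ ]:
# Load the best checkpoint written by `spacy train` and try it on a sentence.
trained_nlp = spacy.load("/tmp/model/model-best")
doc = trained_nlp("Dumbarton 1 Brechin 1")
print([(ent.text, ent.label_) for ent in doc.ents])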

Summary#

Writing precise heuristic rules is a key component to weak supervision workflows with skweak. Rubrix makes it easier for us to identify patterns in our data, create new rules and then debug our labelling functions. In this way we are able to accelerate our data annotation pipelines and quickly train new models that score above common zero-shot baselines.

Next steps#

⭐ Star the Rubrix GitHub repo to stay updated.

📚 Rubrix documentation for more guides and tutorials.

πŸ™‹β€β™€οΈ Join the Rubrix community! A good place to start is the discussion forum.
