🔫 Zero-shot Named Entity Recognition with Flair

TL;DR:

You can use Rubrix to analyze and validate the NER predictions from the new zero-shot model provided by the Flair NLP library.

This is useful for quickly bootstrapping a training set (using Rubrix's Annotation Mode), as well as for integrating with weak-supervision workflows.

(Screenshot: exploring zero-shot NER predictions on WNUT 17 in Rubrix)

Install dependencies

[ ]:
%pip install datasets flair -qqq

Setup Rubrix

Rubrix is a free and open-source tool to explore, annotate, and monitor data for NLP projects.

If you are new to Rubrix, check out the GitHub repository ⭐.

If you have not installed and launched Rubrix, check the Setup and Installation guide.

Once installed, you only need to import Rubrix:

[1]:
import rubrix as rb

Load the wnut_17 dataset

In this example, we’ll use a challenging NER dataset, the “WNUT 17: Emerging and Rare entity recognition” dataset, which focuses on unusual, previously-unseen entities in the context of emerging discussions. This dataset is useful for getting a sense of the quality of our zero-shot predictions.

Let’s load the test set from the Hugging Face Hub:

[ ]:
from datasets import load_dataset

dataset = load_dataset("wnut_17", split="test")
[7]:
wnut_labels = ['corporation', 'creative-work', 'group', 'location', 'person', 'product']
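These entity types can also be recovered programmatically: the dataset's `ner_tags` feature lists BIO tags (e.g. `B-location`, `I-location`), available via `dataset.features["ner_tags"].feature.names`, and stripping the `B-`/`I-` prefixes yields the six types above. A small sketch, hardcoding the tag names as published on the Hugging Face Hub so it runs without re-downloading the dataset:

```python
# BIO tag names for wnut_17, as exposed by the Hugging Face `datasets` feature schema
ner_tags = [
    "O",
    "B-corporation", "I-corporation",
    "B-creative-work", "I-creative-work",
    "B-group", "I-group",
    "B-location", "I-location",
    "B-person", "I-person",
    "B-product", "I-product",
]

# Drop the "O" tag and strip the "B-"/"I-" prefix to get the entity types
wnut_labels = sorted({tag.split("-", 1)[1] for tag in ner_tags if tag != "O"})
# wnut_labels == ['corporation', 'creative-work', 'group', 'location', 'person', 'product']
```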

Configure Flair TARSTagger

Now let’s configure our NER model, following Flair’s documentation.

[ ]:
from flair.models import TARSTagger
from flair.data import Sentence

# Load zero-shot NER tagger
tars = TARSTagger.load('tars-ner')

# Define labels for named entities using wnut labels
labels = wnut_labels
tars.add_and_switch_to_new_task('task 1', labels, label_type='ner')

Let’s test it with one example!

[9]:
sentence = Sentence(" ".join(dataset[0]['tokens']))
[10]:
tars.predict(sentence)

# Creating the prediction entity as a list of tuples (entity, start_char, end_char)
prediction = [
    (entity.get_labels()[0].value, entity.start_pos, entity.end_pos)
    for entity in sentence.get_spans("ner")
]
prediction
[10]:
[('location', 100, 107)]
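Each prediction is a `(label, start_char, end_char)` tuple over the whitespace-joined text, so the surface form of an entity can be sliced straight out of the input string. A minimal sketch with a made-up sentence and offsets (not the actual WNUT example above):

```python
# Hypothetical input text and prediction, illustrating the tuple format only
text = "Flair was created at Humboldt University in Berlin"
prediction = [("location", 44, 50)]

# Character offsets index directly into the joined text
label, start, end = prediction[0]
assert text[start:end] == "Berlin"
```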

Predict over wnut_17 and log into Rubrix

Now, let’s log the predictions into Rubrix:

[ ]:
records = []
for record in dataset.select(range(100)):
    input_text = " ".join(record["tokens"])

    sentence = Sentence(input_text)
    tars.predict(sentence)
    prediction = [
        (entity.get_labels()[0].value, entity.start_pos, entity.end_pos)
        for entity in sentence.get_spans("ner")
    ]

    # Building TokenClassificationRecord
    records.append(
        rb.TokenClassificationRecord(
            text=input_text,
            tokens=[token.text for token in sentence],
            prediction=prediction,
            prediction_agent="tars-ner",
        )
    )

rb.log(records, name='tars_ner_wnut_17', metadata={"split": "test"})