🔫 Zero-shot Named Entity Recognition with Flair


You can use Rubrix for analyzing and validating the NER predictions made by the new zero-shot model provided by the Flair NLP library.

This is useful for quickly bootstrapping a training set (using Rubrix Annotation Mode) as well as integrating with weak-supervision workflows.


Install dependencies

[ ]:
%pip install datasets flair -qqq

Setup Rubrix

Rubrix is a free and open-source tool to explore, annotate, and monitor data for NLP projects.

If you are new to Rubrix, check out the Github repository ⭐.

If you have not installed and launched Rubrix, check the Setup and Installation guide.

Once installed, you only need to import Rubrix:

import rubrix as rb

Load the wnut_17 dataset

In this example, we’ll use a challenging NER dataset, the “WNUT 17: Emerging and Rare entity recognition” dataset, which focuses on unusual, previously-unseen entities in the context of emerging discussions. This dataset is useful for getting a sense of the quality of our zero-shot predictions.

Let’s load the test set from the Hugging Face Hub:

[ ]:
from datasets import load_dataset

dataset = load_dataset("wnut_17", split="test")
wnut_labels = ['corporation', 'creative-work', 'group', 'location', 'person', 'product']
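The dataset stores each example as a list of tokens, while Flair's zero-shot predictions (used below) come back as character spans over the whitespace-joined text. A minimal sketch, using made-up tokens rather than actual wnut_17 data, of how a whitespace join maps tokens to character offsets:

```python
# Illustrative tokens only; not taken from wnut_17
tokens = ["Empire", "State", "Building", "is", "in", "NYC"]

# Join tokens with single spaces, as done throughout this tutorial
text = " ".join(tokens)

# Recover each token's (start_char, end_char) span in the joined text
offsets = []
start = 0
for token in tokens:
    end = start + len(token)
    offsets.append((token, start, end))
    start = end + 1  # skip the joining space

print(text)     # Empire State Building is in NYC
print(offsets)  # [('Empire', 0, 6), ('State', 7, 12), ...]
```

This is the same convention the character offsets in the predictions below follow, so a span like `(label, 28, 31)` would cover `"NYC"` in this example.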

Configure Flair TARSTagger

Now let’s configure our NER model, following Flair’s documentation.

[ ]:
from flair.models import TARSTagger
from flair.data import Sentence

# Load zero-shot NER tagger
tars = TARSTagger.load('tars-ner')

# Define labels for named entities using wnut labels
labels = wnut_labels
tars.add_and_switch_to_new_task('task 1', labels, label_type='ner')

Let’s test it with one example!

sentence = Sentence(" ".join(dataset[0]['tokens']))

# Predict with the zero-shot NER tagger
tars.predict(sentence)

# Creating the prediction entity as a list of tuples (entity, start_char, end_char)
prediction = [
    (entity.get_labels()[0].value, entity.start_pos, entity.end_pos)
    for entity in sentence.get_spans("ner")
]
prediction

[('location', 100, 107)]

Predict over wnut_17 and log into Rubrix

Now, let’s predict over the first 100 examples and log the records into Rubrix:

[ ]:
records = []
for record in dataset.select(range(100)):
    input_text = " ".join(record["tokens"])

    sentence = Sentence(input_text)

    # Predict with the zero-shot NER tagger
    tars.predict(sentence)
    prediction = [
        (entity.get_labels()[0].value, entity.start_pos, entity.end_pos)
        for entity in sentence.get_spans("ner")
    ]

    # Building TokenClassificationRecord
    records.append(
        rb.TokenClassificationRecord(
            text=input_text,
            tokens=[token.text for token in sentence],
            prediction=prediction,
            prediction_agent="tars-ner",
        )
    )

rb.log(records, name='tars_ner_wnut_17', metadata={"split": "test"})
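Besides exploring the predictions in the Rubrix UI, you may want a quick offline sense of zero-shot quality. A minimal sketch with hypothetical spans (the tuples below are made up, not model output) of span-level exact-match precision and recall over the same `(label, start_char, end_char)` format logged above:

```python
# Hypothetical predicted and gold spans in (label, start_char, end_char) form
predicted = [("location", 100, 107), ("person", 0, 10)]
gold = [("location", 100, 107), ("group", 20, 30)]

# Exact match: a predicted span counts only if label and both offsets agree
true_positives = len(set(predicted) & set(gold))
precision = true_positives / len(predicted) if predicted else 0.0
recall = true_positives / len(gold) if gold else 0.0

print(precision, recall)  # -> 0.5 0.5
```

Exact-match scoring is strict; for partial-overlap credit you would compare offset ranges instead of whole tuples.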