🔫 Zero-shot Named Entity Recognition with Flair¶
TL;DR:¶
You can use Rubrix for analyzing and validating the NER predictions from the new zero-shot model provided by the Flair NLP library.
This is useful for quickly bootstrapping a training set (using Rubrix Annotation Mode) as well as integrating with weak-supervision workflows.
Install dependencies¶
[ ]:
%pip install datasets flair -qqq
Setup Rubrix¶
Rubrix is a free and open-source tool to explore, annotate, and monitor data for NLP projects.
If you are new to Rubrix, check out the GitHub repository ⭐.
If you have not installed and launched Rubrix, check the Setup and Installation guide.
Once installed, you only need to import Rubrix:
[1]:
import rubrix as rb
Load the wnut_17 dataset¶
In this example, we’ll use a challenging NER dataset, the “WNUT 17: Emerging and Rare entity recognition” dataset, which focuses on unusual, previously-unseen entities in the context of emerging discussions. This dataset is useful for getting a sense of the quality of our zero-shot predictions.
Let’s load the test set from the Hugging Face Hub:
[ ]:
from datasets import load_dataset
dataset = load_dataset("wnut_17", split="test")
[7]:
wnut_labels = ['corporation', 'creative-work', 'group', 'location', 'person', 'product']
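As a sanity check, these labels can also be recovered from the dataset's own NER tag names by stripping the BIO prefixes (`B-`/`I-`) and dropping the `O` tag. The `ner_tags` list below is written out inline for illustration; in the notebook it corresponds to `dataset.features["ner_tags"].feature.names`.

```python
# BIO tag names as exposed by the wnut_17 dataset features
ner_tags = [
    'O',
    'B-corporation', 'I-corporation',
    'B-creative-work', 'I-creative-work',
    'B-group', 'I-group',
    'B-location', 'I-location',
    'B-person', 'I-person',
    'B-product', 'I-product',
]

# Strip the "B-"/"I-" prefixes and drop "O" to get the plain entity labels
wnut_labels = sorted({tag.split("-", 1)[1] for tag in ner_tags if tag != "O"})
print(wnut_labels)
# ['corporation', 'creative-work', 'group', 'location', 'person', 'product']
```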
Configure Flair TARSTagger¶
Now let’s configure our NER model, following Flair’s documentation.
[ ]:
from flair.models import TARSTagger
from flair.data import Sentence
# Load zero-shot NER tagger
tars = TARSTagger.load('tars-ner')
# Define labels for named entities using wnut labels
labels = wnut_labels
tars.add_and_switch_to_new_task('task 1', labels, label_type='ner')
Let’s test it with one example!
[9]:
sentence = Sentence(" ".join(dataset[0]['tokens']))
[10]:
tars.predict(sentence)
# Creating the prediction entity as a list of tuples (entity, start_char, end_char)
prediction = [
    (entity.get_labels()[0].value, entity.start_pos, entity.end_pos)
    for entity in sentence.get_spans("ner")
]
prediction
[10]:
[('location', 100, 107)]
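The tuples above use character offsets into the joined sentence string, so each predicted span can be sliced straight out of the text. A minimal sketch, using a made-up sentence and prediction rather than the actual WNUT example:

```python
# Hypothetical sentence and prediction, for illustration only:
# offsets are character positions into the text, as in the tuples above
text = "I will be in London next week"
prediction = [("location", 13, 19)]

# Slice each predicted span out of the text by its character offsets
for label, start, end in prediction:
    print(f"{label}: {text[start:end]!r}")
# location: 'London'
```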
Predict over wnut_17 and log into Rubrix¶
Now, let’s log the predictions in Rubrix:
[ ]:
records = []
for record in dataset.select(range(100)):
    input_text = " ".join(record["tokens"])

    sentence = Sentence(input_text)
    tars.predict(sentence)
    prediction = [
        (entity.get_labels()[0].value, entity.start_pos, entity.end_pos)
        for entity in sentence.get_spans("ner")
    ]

    # Building TokenClassificationRecord
    records.append(
        rb.TokenClassificationRecord(
            text=input_text,
            tokens=[token.text for token in sentence],
            prediction=prediction,
            prediction_agent="tars-ner",
        )
    )
rb.log(records, name='tars_ner_wnut_17', metadata={"split": "test"})