
Commit 5ec2baf

Merge pull request #78 from flairNLP/removal_of_spacy

remove spacy dependency

2 parents 4934866 + 5c300a2

5 files changed: +167 -117 lines changed

requirements.txt (0 additions, 1 deletion)

@@ -1,4 +1,3 @@
 datasets
 farm-haystack>=1.18.0
-spacy
 loguru

src/fabricator/dataset_transformations/__init__.py (1 addition, 2 deletions)

@@ -7,9 +7,8 @@
     "replace_class_labels",
     "convert_token_labels_to_spans",
     "convert_spans_to_token_labels",
-    "replace_token_labels",
 ]

 from .question_answering import preprocess_squad_format, postprocess_squad_format, calculate_answer_start
 from .text_classification import convert_label_ids_to_texts, get_labels_from_dataset, replace_class_labels
-from .token_classification import convert_token_labels_to_spans, convert_spans_to_token_labels, replace_token_labels
+from .token_classification import convert_token_labels_to_spans, convert_spans_to_token_labels
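After this change, replace_token_labels is no longer part of the public API. A quick sketch of the imports that remain available (assuming the package is installed as fabricator):

from fabricator.dataset_transformations import (
    convert_token_labels_to_spans,
    convert_spans_to_token_labels,
)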

src/fabricator/dataset_transformations/token_classification.py (133 additions, 89 deletions)

@@ -1,139 +1,183 @@
 import re
-from typing import Dict, List, Tuple
-from collections import defaultdict
+from typing import Dict, List, Tuple, Union
 from datasets import Dataset, Sequence

-from tqdm import tqdm
-import spacy
-from spacy.vocab import Vocab
-from spacy.tokens import Doc
-from spacy.training import iob_to_biluo, biluo_tags_to_offsets, offsets_to_biluo_tags, biluo_to_iob
+from loguru import logger

 # These are fixed for encoding the prompt and decoding the output of the LLM
-LABEL_SEPARATOR = "\n"
-LABEL2ENTITY_SEPARATOR = "->"
-ENTITY_SEPARATOR = ", "
+SPAN_ANNOTATION_TEMPLATE = "{entity} is {label} entity."
+SPAN_ANNOTATION_REGEX = r'(.+) is (.+) entity\.'


 def convert_token_labels_to_spans(
-    dataset: Dataset, token_column: str, label_column: str, expanded_label_mapping: Dict = None
-) -> Tuple[Dataset, List[str]]:
+    dataset: Dataset,
+    token_column: str,
+    label_column: str,
+    expanded_label_mapping: Dict = None,
+    return_label_options: bool = False
+) -> Union[Dataset, Tuple[Dataset, List[str]]]:
     """Converts token level labels to spans. Useful for NER tasks to prompt the LLM with natural language labels.

     Args:
         dataset (Dataset): huggingface Dataset with token level labels
         token_column (str): name of the column with the tokens
         label_column (str): name of the column with the token level labels
         expanded_label_mapping (Dict): mapping from label ids to label names. Defaults to None.
+        return_label_options (bool): whether to return a list of all possible annotations of the provided dataset

     Returns:
         Tuple[Dataset, List[str]]: huggingface Dataset with span labels and list of possible labels for the prompt
     """
     if expanded_label_mapping:
+        if not len(expanded_label_mapping) == len(dataset.features[label_column].feature.names):
+            raise ValueError(
+                f"Length of expanded label mapping and original number of labels in dataset do not match.\n"
+                f"Original labels: {dataset.features[label_column].feature.names}"
+                f"Expanded labels: {list(expanded_label_mapping.values())}"
+            )
         id2label = expanded_label_mapping
     elif isinstance(dataset.features[label_column], Sequence):
         id2label = dict(enumerate(dataset.features[label_column].feature.names))
     else:
         raise ValueError("Labels must be a Sequence feature or expanded_label_mapping must be provided.")

-    new_label_column = f"{label_column}_natural_language"
-    label_options = list({label.replace("B-", "").replace("I-", "") for label in id2label.values()})
-    if "O" in label_options:
-        label_options.remove("O")
+    span_column = "span_annotations"

-    def labels_to_spans(examples):
-        bio_tags = [id2label[label] for label in examples[label_column]]
-        bilou_tags = iob_to_biluo(bio_tags)
-        doc = Doc(Vocab(), words=examples[token_column])
-        offsets = biluo_tags_to_offsets(doc, bilou_tags)
+    def labels_to_spans(example):
+        span_annotations = [id2label.get(label).replace("B-", "").replace("I-", "") for label in example[label_column]]

-        span_labels = defaultdict(list)
-        for start, end, label in offsets:
-            span_labels[label].append(doc.text[start:end])
+        annotations_for_prompt = ""

-        examples[token_column] = doc.text
-        span_labels = {k: ENTITY_SEPARATOR.join(v) for k, v in span_labels.items()}
-        examples[new_label_column] = LABEL_SEPARATOR.join(
-            [f"{k} {LABEL2ENTITY_SEPARATOR} {v}" for k, v in span_labels.items()]
-        )
-        return examples
+        current_entity = None
+        current_entity_type = None
+        for idx, span_annotation in enumerate(span_annotations):
+            if span_annotation == "O":
+                if current_entity is not None:
+                    annotations_for_prompt += SPAN_ANNOTATION_TEMPLATE.format(entity=current_entity,
+                                                                              label=current_entity_type) + "\n"
+                    current_entity = None
+                    current_entity_type = None
+                continue
+            if current_entity is None:
+                current_entity = example[token_column][idx]
+                current_entity_type = span_annotation
+                continue
+            if current_entity_type == span_annotation:
+                current_entity += " " + example[token_column][idx]
+            else:
+                annotations_for_prompt += SPAN_ANNOTATION_TEMPLATE.format(entity=current_entity,
+                                                                          label=current_entity_type) + "\n"
+                current_entity = example[token_column][idx]
+                current_entity_type = span_annotation
+
+        if current_entity is not None:
+            annotations_for_prompt += SPAN_ANNOTATION_TEMPLATE.format(entity=current_entity,
+                                                                      label=current_entity_type) + "\n"
+
+        example[token_column] = " ".join(example[token_column])
+        example[span_column] = annotations_for_prompt.rstrip("\n")
+        return example
+
+    dataset = dataset.map(labels_to_spans).remove_columns(label_column).rename_column(span_column, label_column)
+
+    if return_label_options:
+        # Spans have implicit BIO format, so sequences come in BIO format, we can ignore it
+        label_options = list({label.replace("B-", "").replace("I-", "") for label in id2label.values()})
+
+        # Ignore "outside" tokens
+        if "O" in label_options:
+            label_options.remove("O")

-    dataset = dataset.map(labels_to_spans).remove_columns(label_column).rename_column(new_label_column, label_column)
+        return dataset, label_options

-    return dataset, label_options
+    return dataset


-def convert_spans_to_token_labels(dataset, token_column, label_column, id2label: Dict) -> Dataset:
-    """Converts span level labels to token level labels. This is useful for NER tasks to decode the output of the LLM.
+def convert_spans_to_token_labels(
+    dataset: Dataset,
+    token_column: str,
+    label_column: str,
+    id2label: Dict,
+    annotate_identical_words: bool = False
+) -> Dataset:
+    """Converts span level labels to token level labels.
+    First, the function extracts all entities with its annotated types.
+    Second, if annotations are present, the function converts them to a tag sequence in BIO format.
+    If not present, simply return tag sequence of O-tokens.
+    This is useful for NER tasks to decode the output of the LLM.

     Args:
         dataset (Dataset): huggingface Dataset with span level labels
         token_column (str): name of the column with the tokens
         label_column (str): name of the column with the span level labels
         id2label (Dict): mapping from label ids to label names
+        annotate_identical_words (bool): whether to annotate all identical words in a sentence with a found entity
+            type

     Returns:
         Dataset: huggingface Dataset with token level labels in BIO format
     """
-    new_label_column = f"{label_column}_tags"
-    label2id = {v: k for k, v in id2label.items()}
-    labels_no_bio = set([label.replace("B-", "").replace("I-", "") for label in id2label.values()])
-    nlp = spacy.blank("en")
-
-    def labels_to_spans(examples):
-        texts = examples[token_column]
-        str_labels = examples[label_column]
-        # goal list of lists of tuples (start, end, label)
-
-        tokens = []
-        bio_tags = []
-        for text, str_label in tqdm(zip(texts, str_labels), desc="Converting spans to token labels"):
-            spans = []
-
-            if not str_label:
-                bio_tags.append([])
-                tokens.append([])
-                continue
-
-            try:
-                for label_and_entities in str_label.split(LABEL_SEPARATOR):
-                    label, entities = label_and_entities.split(LABEL2ENTITY_SEPARATOR)
-                    label = label.strip()
-                    if label not in labels_no_bio:
-                        continue
-                    entities = [entity.strip().lower() for entity in entities.split(ENTITY_SEPARATOR)]
-                    for entity in set(entities):
-                        pattern = re.compile(r'\b' + re.escape(entity) + r'\b')
-                        matches = pattern.finditer(text.lower())
-                        for start, end in [(match.start(), match.end()) for match in matches]:
-                            spans.append((start, end, label))
-            except ValueError:
-                bio_tags.append([])
-                tokens.append([])
-                continue
-
-            doc = nlp(text)
-
-            try:
-                tags = [tag if tag != "-" else "O" for tag in biluo_to_iob(offsets_to_biluo_tags(doc, spans))]
-                words = [word.text for word in doc]
-                if not len(tags) == len(words) or len(tags) == 0 or len(words) == 0:
-                    tags = []
-                    words = []
-                bio_tags.append(tags)
-                tokens.append(words)
-            except ValueError:
-                bio_tags.append([])
-                tokens.append([])
-                continue
+    new_label_column = "sequence_tags"
+    lower_label2id = {label.lower(): idx for idx, label in id2label.items()}
+
+    def labels_to_spans(example):
+        span_annotations = example[label_column].split("\n")
+
+        ner_tag_tuples = []
+
+        for span_annotation in span_annotations:
+            matches = re.match(SPAN_ANNOTATION_REGEX, span_annotation)
+            if matches:
+                matched_entity = matches.group(1)
+                matched_label = matches.group(2)
+
+                span_tokens = matched_entity.split(" ")
+                span_labels = ["B-" + matched_label if idx == 0 else "B-" + matched_label.lower()
+                               for idx, _ in enumerate(span_tokens)]
+
+                for token, label in zip(span_tokens, span_labels):
+                    label_id = lower_label2id.get(label.lower())
+                    if label_id is None:
+                        logger.info(f"Entity {token} with label {label} is not in id2label: {id2label}.")
+                    else:
+                        ner_tag_tuples.append((token, label_id))
+            else:
+                pass
+
+        if ner_tag_tuples:
+            lower_tokens = example[token_column].lower().split(" ")
+            # initialize all tokens with O type
+            ner_tags = [0] * len(lower_tokens)
+            for reference_token, entity_type_id in ner_tag_tuples:
+                if lower_tokens.count(reference_token.lower()) == 0:
+                    logger.info(
+                        f"Entity {reference_token} is not found or occurs more than once: {lower_tokens}. "
+                        f"Thus, setting label to O."
+                    )
+                elif lower_tokens.count(reference_token.lower()) > 1:
+                    if annotate_identical_words:
+                        insert_at_idxs = [index for index, value in enumerate(lower_tokens)
+                                          if value == reference_token.lower()]
+                        for insert_at_idx in insert_at_idxs:
+                            ner_tags[insert_at_idx] = entity_type_id
+                    else:
+                        logger.info(
+                            f"Entity {reference_token} occurs more than once: {lower_tokens}. "
+                            f"Thus, setting label to O."
+                        )
+                else:
+                    insert_at_idx = lower_tokens.index(reference_token.lower())
+                    ner_tags[insert_at_idx] = entity_type_id
+        else:
+            ner_tags = [0] * len(example[token_column].split(" "))

-        examples[token_column] = tokens
-        examples[new_label_column] = [[label2id[tag] for tag in tags] for tags in bio_tags]
+        example[token_column] = example[token_column].split(" ")
+        example[new_label_column] = ner_tags

-        return examples
+        return example

     dataset = (
-        dataset.map(labels_to_spans, batched=True)
+        dataset.map(labels_to_spans)
         .remove_columns(label_column)
         .rename_column(new_label_column, label_column)
     )
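The core of the change: spaCy's BILUO offset machinery is replaced by a plain-text template. Token-level BIO tags are merged into "{entity} is {label} entity." lines for the prompt, and LLM output is parsed back with SPAN_ANNOTATION_REGEX. A minimal, self-contained sketch of that round trip (illustrative only, not part of the commit; the sentence and tags are made up):

import re

SPAN_ANNOTATION_TEMPLATE = "{entity} is {label} entity."
SPAN_ANNOTATION_REGEX = r'(.+) is (.+) entity\.'

tokens = ["Peter", "Blackburn", "rejects", "BRUSSELS"]
bio_tags = ["B-PER", "I-PER", "O", "B-LOC"]

# Encode: strip BIO prefixes and merge consecutive tokens of the same type.
lines, entity, entity_type = [], None, None
for token, tag in zip(tokens, bio_tags):
    label = tag.replace("B-", "").replace("I-", "")
    if entity and (label == "O" or label != entity_type):
        lines.append(SPAN_ANNOTATION_TEMPLATE.format(entity=entity, label=entity_type))
        entity, entity_type = None, None
    if label != "O":
        if entity is None:
            entity, entity_type = token, label
        else:
            entity += " " + token
if entity:
    lines.append(SPAN_ANNOTATION_TEMPLATE.format(entity=entity, label=entity_type))

print("\n".join(lines))
# Peter Blackburn is PER entity.
# BRUSSELS is LOC entity.

# Decode: the regex recovers (entity, label) pairs from each annotation line.
for line in lines:
    entity, label = re.match(SPAN_ANNOTATION_REGEX, line).groups()
    print((entity, label))  # ('Peter Blackburn', 'PER'), then ('BRUSSELS', 'LOC')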

tests/test_dataset_transformations.py (24 additions, 19 deletions)

@@ -83,29 +83,27 @@ class TestTransformationsTokenClassification(unittest.TestCase):
     """Testcase for TokenLabelTransformations"""

     def setUp(self) -> None:
-        self.dataset = load_dataset("conll2003", split="train")
+        self.dataset = load_dataset("conll2003", split="train").select(range(150))

     def test_bio_tokens_to_spans(self):
         """Test transformation output only (BIO to spans)"""
         dataset, label_options = convert_token_labels_to_spans(
-            self.dataset, "tokens", "ner_tags"
+            self.dataset, "tokens", "ner_tags", return_label_options=True
         )
         self.assertEqual(len(label_options), 4)
         self.assertEqual(type(dataset[0]["ner_tags"]), str)
         self.assertNotEqual(type(dataset[0]["ner_tags"]), int)
-        labels = [
-            spans.split(LABEL2ENTITY_SEPARATOR, 1)[0].strip()
-            for spans in dataset[0]["ner_tags"].split(LABEL_SEPARATOR)
-        ]
-        for label in labels:
-            self.assertIn(label, label_options)
+        spans = [span for span in dataset[0]["ner_tags"].split("\n")]
+        for span in spans:
+            self.assertTrue(any([label in span for label in label_options]))

     def test_formatting_with_span_labels(self):
         """Test formatting with span labels"""
         dataset, label_options = convert_token_labels_to_spans(
             dataset=self.dataset,
             token_column="tokens",
             label_column="ner_tags",
+            return_label_options=True
         )
         fewshot_examples = dataset.select([1, 2, 3])
         prompt = BasePrompt(
@@ -115,25 +113,31 @@ def test_formatting_with_span_labels(self):
             label_options=label_options,
         )
         raw_prompt = prompt.get_prompt_text(label_options, fewshot_examples)
-        self.assertIn("PER -> Peter Blackburn", raw_prompt)
-        self.assertIn("LOC -> BRUSSELS", raw_prompt)
+        self.assertIn("Peter Blackburn is PER entity.", raw_prompt)
+        self.assertIn("BRUSSELS is LOC entity.", raw_prompt)
         for label in label_options:
             self.assertIn(label, raw_prompt)

     def test_expanded_textual_labels(self):
         """Test formatting with expanded textual labels"""
-        extended_mapping = {"PER": "person", "LOC": "location", "ORG": "organization", "MISC": "misceallaneous"}
-        id2label = replace_token_labels(dict(enumerate(self.dataset.features["ner_tags"].feature.names)), extended_mapping)
-        self.assertIn("B-location", id2label.values())
-        self.assertIn("I-person", id2label.values())
-        self.assertNotIn("B-LOC", id2label.values())
-        self.assertNotIn("I-MISC", id2label.values())
+        expanded_label_mapping = {
+            0: "O",
+            1: "B-person",
+            2: "I-person",
+            3: "B-location",
+            4: "I-location",
+            5: "B-organization",
+            6: "I-organization",
+            7: "B-miscellaneous",
+            8: "I-miscellaneous",
+        }

         dataset, label_options = convert_token_labels_to_spans(
             dataset=self.dataset,
             token_column="tokens",
             label_column="ner_tags",
-            expanded_label_mapping=id2label
+            expanded_label_mapping=expanded_label_mapping,
+            return_label_options=True
         )
         fewshot_examples = dataset.select([1, 2, 3])
         prompt = BasePrompt(
@@ -143,7 +147,7 @@ def test_expanded_textual_labels(self):
             label_options=label_options,
         )
         raw_prompt = prompt.get_prompt_text(label_options, fewshot_examples)
-        self.assertIn("person -> Peter Blackburn", raw_prompt)
+        self.assertIn("Peter Blackburn is person entity.", raw_prompt)
         self.assertNotIn("PER", raw_prompt)
         for label in label_options:
             self.assertIn(label, raw_prompt)
@@ -154,9 +158,10 @@ def test_textual_labels_to_label_ids(self):
             dataset=self.dataset,
             token_column="tokens",
             label_column="ner_tags",
+            return_label_options=True
         )
         id2label = dict(enumerate(self.dataset.features["ner_tags"].feature.names))
-        self.assertEqual(dataset[0]["ner_tags"], "ORG -> EU\nMISC -> German, British")
+        self.assertEqual(dataset[0]["ner_tags"], "EU is ORG entity.\nGerman is MISC entity.\nBritish is MISC entity.")
         dataset = dataset.select(range(10))
         dataset = convert_spans_to_token_labels(
             dataset=dataset,
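Taken together, the updated tests imply the following end-to-end usage; a sketch assuming fabricator is installed and conll2003 can be downloaded:

from datasets import load_dataset
from fabricator.dataset_transformations import (
    convert_token_labels_to_spans,
    convert_spans_to_token_labels,
)

dataset = load_dataset("conll2003", split="train").select(range(150))
# Keep the original label names around before the column is rewritten.
id2label = dict(enumerate(dataset.features["ner_tags"].feature.names))

# BIO label ids -> natural-language span annotations for prompting.
dataset, label_options = convert_token_labels_to_spans(
    dataset=dataset,
    token_column="tokens",
    label_column="ner_tags",
    return_label_options=True,  # label options are now opt-in
)
print(dataset[0]["ner_tags"])
# EU is ORG entity.
# German is MISC entity.
# British is MISC entity.

# ...and back: span annotations -> token-level BIO label ids for decoding.
dataset = convert_spans_to_token_labels(
    dataset=dataset.select(range(10)),
    token_column="tokens",
    label_column="ner_tags",
    id2label=id2label,
)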
