import re

-from typing import Dict, List, Tuple
-from collections import defaultdict
+from typing import Dict, List, Tuple, Union
from datasets import Dataset, Sequence

-from tqdm import tqdm
-import spacy
-from spacy.vocab import Vocab
-from spacy.tokens import Doc
-from spacy.training import iob_to_biluo, biluo_tags_to_offsets, offsets_to_biluo_tags, biluo_to_iob
+from loguru import logger

# These are fixed for encoding the prompt and decoding the output of the LLM
-LABEL_SEPARATOR = "\n"
-LABEL2ENTITY_SEPARATOR = "->"
-ENTITY_SEPARATOR = ", "
+SPAN_ANNOTATION_TEMPLATE = "{entity} is {label} entity."
+SPAN_ANNOTATION_REGEX = r'(.+) is (.+) entity\.'
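+# For illustration: SPAN_ANNOTATION_TEMPLATE.format(entity="New York", label="location")
+# renders as "New York is location entity.", and SPAN_ANNOTATION_REGEX parses such a
+# line back into its (entity, label) pair.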


def convert_token_labels_to_spans(
-    dataset: Dataset, token_column: str, label_column: str, expanded_label_mapping: Dict = None
-) -> Tuple[Dataset, List[str]]:
+    dataset: Dataset,
+    token_column: str,
+    label_column: str,
+    expanded_label_mapping: Dict = None,
+    return_label_options: bool = False
+) -> Union[Dataset, Tuple[Dataset, List[str]]]:
    """Converts token level labels to spans. Useful for NER tasks to prompt the LLM with natural language labels.

    Args:
        dataset (Dataset): huggingface Dataset with token level labels
        token_column (str): name of the column with the tokens
        label_column (str): name of the column with the token level labels
        expanded_label_mapping (Dict): mapping from label ids to label names. Defaults to None.
+        return_label_options (bool): whether to additionally return the list of possible labels for the prompt.
+            Defaults to False.

    Returns:
-        Tuple[Dataset, List[str]]: huggingface Dataset with span labels and list of possible labels for the prompt
+        Union[Dataset, Tuple[Dataset, List[str]]]: huggingface Dataset with span labels, plus the list of possible
+            labels for the prompt if return_label_options is True
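+
+    Example (an illustrative sketch; the column and label names are made up):
+        from datasets import ClassLabel, Features, Value
+        features = Features({
+            "tokens": Sequence(Value("string")),
+            "ner_tags": Sequence(ClassLabel(names=["O", "B-PER", "I-PER", "B-LOC", "I-LOC"])),
+        })
+        dataset = Dataset.from_dict(
+            {"tokens": [["John", "lives", "in", "New", "York"]], "ner_tags": [[1, 0, 0, 3, 4]]},
+            features=features,
+        )
+        dataset, label_options = convert_token_labels_to_spans(
+            dataset, "tokens", "ner_tags", return_label_options=True
+        )
+        # dataset[0]["ner_tags"] == "John is PER entity.\nNew York is LOC entity."
+        # sorted(label_options) == ["LOC", "PER"]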
    """
    if expanded_label_mapping:
+        if len(expanded_label_mapping) != len(dataset.features[label_column].feature.names):
+            raise ValueError(
+                f"Length of expanded label mapping and original number of labels in dataset do not match.\n"
+                f"Original labels: {dataset.features[label_column].feature.names}\n"
+                f"Expanded labels: {list(expanded_label_mapping.values())}"
+            )
        id2label = expanded_label_mapping
    elif isinstance(dataset.features[label_column], Sequence):
        id2label = dict(enumerate(dataset.features[label_column].feature.names))
    else:
        raise ValueError("Labels must be a Sequence feature or expanded_label_mapping must be provided.")

-    new_label_column = f"{label_column}_natural_language"
-    label_options = list({label.replace("B-", "").replace("I-", "") for label in id2label.values()})
-    if "O" in label_options:
-        label_options.remove("O")
+    span_column = "span_annotations"

-    def labels_to_spans(examples):
-        bio_tags = [id2label[label] for label in examples[label_column]]
-        bilou_tags = iob_to_biluo(bio_tags)
-        doc = Doc(Vocab(), words=examples[token_column])
-        offsets = biluo_tags_to_offsets(doc, bilou_tags)
+    def labels_to_spans(example):
+        span_annotations = [id2label[label].replace("B-", "").replace("I-", "") for label in example[label_column]]

-        span_labels = defaultdict(list)
-        for start, end, label in offsets:
-            span_labels[label].append(doc.text[start:end])
+        annotations_for_prompt = ""

-        examples[token_column] = doc.text
-        span_labels = {k: ENTITY_SEPARATOR.join(v) for k, v in span_labels.items()}
-        examples[new_label_column] = LABEL_SEPARATOR.join(
-            [f"{k} {LABEL2ENTITY_SEPARATOR} {v}" for k, v in span_labels.items()]
-        )
-        return examples
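+        # Merge consecutive tokens of the same (B-/I-stripped) entity type into a single span;
+        # the current span is flushed whenever an "O" token or a different type is encountered.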
+        current_entity = None
+        current_entity_type = None
+        for idx, span_annotation in enumerate(span_annotations):
+            if span_annotation == "O":
+                if current_entity is not None:
+                    annotations_for_prompt += SPAN_ANNOTATION_TEMPLATE.format(entity=current_entity,
+                                                                              label=current_entity_type) + "\n"
+                    current_entity = None
+                    current_entity_type = None
+                continue
+            if current_entity is None:
+                current_entity = example[token_column][idx]
+                current_entity_type = span_annotation
+                continue
+            if current_entity_type == span_annotation:
+                current_entity += " " + example[token_column][idx]
+            else:
+                annotations_for_prompt += SPAN_ANNOTATION_TEMPLATE.format(entity=current_entity,
+                                                                          label=current_entity_type) + "\n"
+                current_entity = example[token_column][idx]
+                current_entity_type = span_annotation
+
+        if current_entity is not None:
+            annotations_for_prompt += SPAN_ANNOTATION_TEMPLATE.format(entity=current_entity,
+                                                                      label=current_entity_type) + "\n"
+
+        example[token_column] = " ".join(example[token_column])
+        example[span_column] = annotations_for_prompt.rstrip("\n")
+        return example
+
+    dataset = dataset.map(labels_to_spans).remove_columns(label_column).rename_column(span_column, label_column)
+
+    if return_label_options:
+        # The span annotations make the BIO scheme implicit, so the B-/I- prefixes can be dropped here
+        label_options = list({label.replace("B-", "").replace("I-", "") for label in id2label.values()})
+
+        # Ignore "outside" tokens
+        if "O" in label_options:
+            label_options.remove("O")

-    dataset = dataset.map(labels_to_spans).remove_columns(label_column).rename_column(new_label_column, label_column)
+        return dataset, label_options

-    return dataset, label_options
+    return dataset


-def convert_spans_to_token_labels(dataset, token_column, label_column, id2label: Dict) -> Dataset:
-    """Converts span level labels to token level labels. This is useful for NER tasks to decode the output of the LLM.
+def convert_spans_to_token_labels(
+    dataset: Dataset,
+    token_column: str,
+    label_column: str,
+    id2label: Dict,
+    annotate_identical_words: bool = False
+) -> Dataset:
+    """Converts span level labels to token level labels.
+    First, the function extracts all entities with their annotated types.
+    Second, if annotations are present, it converts them to a tag sequence in BIO format;
+    if none are present, it simply returns a tag sequence of all O tags.
+    This is useful for NER tasks to decode the output of the LLM.

    Args:
        dataset (Dataset): huggingface Dataset with span level labels
        token_column (str): name of the column with the tokens
        label_column (str): name of the column with the span level labels
        id2label (Dict): mapping from label ids to label names
+        annotate_identical_words (bool): whether to annotate all identical words in a sentence with the entity
+            type found for one of them. Defaults to False.

    Returns:
        Dataset: huggingface Dataset with token level labels in BIO format
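+
+    Example (an illustrative sketch; the column and label names are made up):
+        dataset = Dataset.from_dict({
+            "tokens": ["John lives in New York"],
+            "ner_tags": ["John is PER entity.\nNew York is LOC entity."],
+        })
+        id2label = {0: "O", 1: "B-PER", 2: "I-PER", 3: "B-LOC", 4: "I-LOC"}
+        dataset = convert_spans_to_token_labels(dataset, "tokens", "ner_tags", id2label)
+        # dataset[0]["tokens"] == ["John", "lives", "in", "New", "York"]
+        # dataset[0]["ner_tags"] == [1, 0, 0, 3, 4]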
    """
-    new_label_column = f"{label_column}_tags"
-    label2id = {v: k for k, v in id2label.items()}
-    labels_no_bio = set([label.replace("B-", "").replace("I-", "") for label in id2label.values()])
-    nlp = spacy.blank("en")
-
-    def labels_to_spans(examples):
-        texts = examples[token_column]
-        str_labels = examples[label_column]
-        # goal list of lists of tuples (start, end, label)
-
-        tokens = []
-        bio_tags = []
-        for text, str_label in tqdm(zip(texts, str_labels), desc="Converting spans to token labels"):
-            spans = []
-
-            if not str_label:
-                bio_tags.append([])
-                tokens.append([])
-                continue
-
-            try:
-                for label_and_entities in str_label.split(LABEL_SEPARATOR):
-                    label, entities = label_and_entities.split(LABEL2ENTITY_SEPARATOR)
-                    label = label.strip()
-                    if label not in labels_no_bio:
-                        continue
-                    entities = [entity.strip().lower() for entity in entities.split(ENTITY_SEPARATOR)]
-                    for entity in set(entities):
-                        pattern = re.compile(r'\b' + re.escape(entity) + r'\b')
-                        matches = pattern.finditer(text.lower())
-                        for start, end in [(match.start(), match.end()) for match in matches]:
-                            spans.append((start, end, label))
-            except ValueError:
-                bio_tags.append([])
-                tokens.append([])
-                continue
-
-            doc = nlp(text)
-
-            try:
-                tags = [tag if tag != "-" else "O" for tag in biluo_to_iob(offsets_to_biluo_tags(doc, spans))]
-                words = [word.text for word in doc]
-                if not len(tags) == len(words) or len(tags) == 0 or len(words) == 0:
-                    tags = []
-                    words = []
-                bio_tags.append(tags)
-                tokens.append(words)
-            except ValueError:
-                bio_tags.append([])
-                tokens.append([])
-                continue
+    new_label_column = "sequence_tags"
+    lower_label2id = {label.lower(): idx for idx, label in id2label.items()}
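+    # Note: lookups are done on lowercased label names so that matching tolerates
+    # casing variations in the labels produced by the LLM.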
+
+    def labels_to_spans(example):
+        span_annotations = example[label_column].split("\n")
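+        # The LLM output is expected to contain one span annotation per line,
+        # each following SPAN_ANNOTATION_TEMPLATE.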
+
+        ner_tag_tuples = []
+
+        for span_annotation in span_annotations:
+            matches = re.match(SPAN_ANNOTATION_REGEX, span_annotation)
+            if matches:
+                matched_entity = matches.group(1)
+                matched_label = matches.group(2)
+
+                span_tokens = matched_entity.split(" ")
+                span_labels = ["B-" + matched_label if idx == 0 else "I-" + matched_label
+                               for idx, _ in enumerate(span_tokens)]
+
+                for token, label in zip(span_tokens, span_labels):
+                    label_id = lower_label2id.get(label.lower())
+                    if label_id is None:
+                        logger.info(f"Entity {token} with label {label} is not in id2label: {id2label}.")
+                    else:
+                        ner_tag_tuples.append((token, label_id))
+
+        if ner_tag_tuples:
+            lower_tokens = example[token_column].lower().split(" ")
+            # initialize all tokens with O type
+            ner_tags = [0] * len(lower_tokens)
+            for reference_token, entity_type_id in ner_tag_tuples:
+                if lower_tokens.count(reference_token.lower()) == 0:
+                    logger.info(
+                        f"Entity {reference_token} is not found in tokens: {lower_tokens}. "
+                        f"Thus, setting label to O."
+                    )
+                elif lower_tokens.count(reference_token.lower()) > 1:
+                    if annotate_identical_words:
+                        insert_at_idxs = [index for index, value in enumerate(lower_tokens)
+                                          if value == reference_token.lower()]
+                        for insert_at_idx in insert_at_idxs:
+                            ner_tags[insert_at_idx] = entity_type_id
+                    else:
+                        logger.info(
+                            f"Entity {reference_token} occurs more than once: {lower_tokens}. "
+                            f"Thus, setting label to O."
+                        )
+                else:
+                    insert_at_idx = lower_tokens.index(reference_token.lower())
+                    ner_tags[insert_at_idx] = entity_type_id
+        else:
+            ner_tags = [0] * len(example[token_column].split(" "))

-        examples[token_column] = tokens
-        examples[new_label_column] = [[label2id[tag] for tag in tags] for tags in bio_tags]
+        example[token_column] = example[token_column].split(" ")
+        example[new_label_column] = ner_tags

-        return examples
+        return example

    dataset = (
-        dataset.map(labels_to_spans, batched=True)
+        dataset.map(labels_to_spans)
        .remove_columns(label_column)
        .rename_column(new_label_column, label_column)
    )