Coverage for src/dataknobs_xization/annotations.py: 31%
493 statements
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-18 17:41 -0700
« prev ^ index » next coverage.py v7.11.3, created at 2025-11-18 17:41 -0700
1"""Text annotation data structures and interfaces.
3Provides classes for managing text annotations with metadata, including
4position tracking, annotation types, and derived annotation columns.
5"""
7import json
8from abc import ABC, abstractmethod
9from collections.abc import Callable
10from typing import Any, Dict, List, Set
12import numpy as np
13import pandas as pd
15import dataknobs_structures.document as dk_doc
# Canonical "key" column types shared by every annotation interface in this
# module. AnnotationsMetaData maps each type to the actual table column name.
KEY_START_POS_COL: str = "start_pos"
KEY_END_POS_COL: str = "end_pos"
KEY_TEXT_COL: str = "text"
KEY_ANN_TYPE_COL: str = "ann_type"
class AnnotationsMetaData(dk_doc.MetaData):
    """Container for annotations meta-data, identifying key column names.

    NOTE: this object contains only information about annotation column names
    and not annotation table values.
    """

    def __init__(
        self,
        start_pos_col: str = KEY_START_POS_COL,
        end_pos_col: str = KEY_END_POS_COL,
        text_col: str = KEY_TEXT_COL,
        ann_type_col: str = KEY_ANN_TYPE_COL,
        sort_fields: List[str] = (KEY_START_POS_COL, KEY_END_POS_COL),
        sort_fields_ascending: List[bool] = (True, False),
        **kwargs: Any,
    ):
        """Initialize with key (and more) column names and info.

        Key column types:
          * start_pos
          * end_pos
          * text
          * ann_type

        Note:
            Actual table columns can be named arbitrarily, BUT interactions
            through annotations classes and interfaces relating to the "key"
            columns must use the key column constants.

        Args:
            start_pos_col: Col name for the token starting position.
            end_pos_col: Col name for the token ending position.
            text_col: Col name for the token text.
            ann_type_col: Col name for the annotation types.
            sort_fields: The col types (or literal col names) relevant for
                sorting annotation rows, or None to disable sorting.
            sort_fields_ascending: The sort order of sort_fields: a single
                bool, or one bool per sort field.
            **kwargs: More column types mapped to column names.
        """
        super().__init__(
            {
                KEY_START_POS_COL: start_pos_col,
                KEY_END_POS_COL: end_pos_col,
                KEY_TEXT_COL: text_col,
                KEY_ANN_TYPE_COL: ann_type_col,
            },
            **kwargs,
        )
        # None must survive as None: sort_df treats it as "don't sort",
        # whereas list(None) would raise TypeError here.
        self.sort_fields = list(sort_fields) if sort_fields is not None else None
        self.ascending = sort_fields_ascending

    @property
    def start_pos_col(self) -> str:
        """Get the column name for the token starting position"""
        return self.data[KEY_START_POS_COL]

    @property
    def end_pos_col(self) -> str:
        """Get the column name for the token ending position"""
        return self.data[KEY_END_POS_COL]

    @property
    def text_col(self) -> str:
        """Get the column name for the token text"""
        return self.data[KEY_TEXT_COL]

    @property
    def ann_type_col(self) -> str:
        """Get the column name for the token annotation type"""
        return self.data[KEY_ANN_TYPE_COL]

    def get_col(self, col_type: str, missing: str = None) -> str:
        """Get the name of the column having the given type (including key column
        types but not derived,) or get the missing value.

        Args:
            col_type: The type of column name to get.
            missing: The value to return for unknown column types.

        Returns:
            The column name or the missing value.
        """
        return self.get_value(col_type, missing)

    def sort_df(self, an_df: pd.DataFrame) -> pd.DataFrame:
        """Sort an annotations dataframe according to this metadata.

        Args:
            an_df: An annotations dataframe.

        Returns:
            The sorted annotations dataframe (or an_df itself when sorting is
            disabled via sort_fields=None).
        """
        if self.sort_fields is not None:
            # sort_fields hold column *types*; translate them to the actual
            # column names (falling back to the field itself) so custom names
            # such as start_pos_col="begin" still sort correctly.
            sort_cols = [self.get_col(field, field) for field in self.sort_fields]
            an_df = an_df.sort_values(sort_cols, ascending=self.ascending)
        return an_df
class DerivedAnnotationColumns(ABC):
    """Abstract hook for injecting "derived" column values into row access.

    Implementations compute a value for a column type that is not stored
    directly in an annotation row.
    """

    @abstractmethod
    def get_col_value(
        self, metadata: AnnotationsMetaData, col_type: str, row: pd.Series, missing: str = None
    ) -> str:
        """Derive the value for ``col_type`` from ``row``.

        Args:
            metadata: The annotations metadata describing the row's columns.
            col_type: The column type whose value should be derived.
            row: The annotation row to derive the value from.
            missing: The value to return for an unknown or missing column.

        Returns:
            The derived value, or ``missing``.
        """
        raise NotImplementedError()
class AnnotationsRowAccessor:
    """A class that accesses row data according to the metadata and derived cols."""

    def __init__(
        self,
        metadata: "AnnotationsMetaData",
        derived_cols: "DerivedAnnotationColumns" = None,
    ):
        """Initialize AnnotationsRowAccessor.

        Args:
            metadata: The metadata for annotation columns.
            derived_cols: A DerivedAnnotationColumns instance for injecting
                derived columns.
        """
        self.metadata = metadata
        self.derived_cols = derived_cols

    def get_col_value(
        self,
        col_type: str,
        row: pd.Series,
        missing: str = None,
    ) -> str:
        """Get the value of the column in the given row with the given type.

        This gets the value from the first existing source, in order:
          * The metadata.get_col(col_type) column, if present in the row
          * A row column literally named col_type
          * The columns derived from col_type (via derived_cols)

        Args:
            col_type: The type of column value to get.
            row: A row from which to get the value.
            missing: The value to return for unknown or missing column.

        Returns:
            The row value or the missing value.
        """
        col = self.metadata.get_col(col_type, None)
        if col is not None and col in row.index:
            return row[col]
        # Check the row itself, not metadata.data: the mapped column (if any)
        # is known to be absent here, so indexing by col_type only works if
        # the row literally has such a column.
        if col_type in row.index:
            return row[col_type]
        if self.derived_cols is not None:
            return self.derived_cols.get_col_value(self.metadata, col_type, row, missing)
        return missing
class Annotations:
    """DAO for collecting and managing a table of annotations, where each row
    carries annotation information for an input token.

    The data in this class is maintained either as a list of dicts, each dict
    representing a "row," or as a pandas DataFrame, depending on the latest
    access. Changes in either the lists or dataframe will be reflected in the
    alternate data structure.
    """

    def __init__(
        self,
        metadata: "AnnotationsMetaData",
        df: pd.DataFrame = None,
    ):
        """Construct as empty or initialize with the dataframe form.

        Args:
            metadata: The annotations metadata.
            df: A dataframe with annotation records.
        """
        self.metadata = metadata
        # Only one of these two forms is held at a time; the other is rebuilt
        # lazily on access.
        self._annotations_list = None
        self._df = df

    @property
    def ann_row_dicts(self) -> List[Dict[str, Any]]:
        """Get the annotations as a list of dictionaries.

        Switches the internal form to the list; any held dataframe is
        released and rebuilt on the next ``df`` access.
        """
        if self._annotations_list is None:
            self._annotations_list = self._build_list()
        return self._annotations_list

    @property
    def df(self) -> pd.DataFrame:
        """Get the annotations as a pandas dataframe, or None if there are none."""
        if self._df is None:
            self._df = self._build_df()
        return self._df

    @property
    def as_df(self) -> pd.DataFrame:
        """Alias for ``df``, for callers that use the ``as_df`` name."""
        return self.df

    def clear(self) -> pd.DataFrame:
        """Clear/empty out all annotations, returning the annotations df"""
        rv = self.df
        self._df = None
        self._annotations_list = None
        return rv

    def is_empty(self) -> bool:
        """Determine whether there are no annotations in either form."""
        return (self._df is None or len(self._df) == 0) and (
            self._annotations_list is None or len(self._annotations_list) == 0
        )

    def add_dict(self, annotation: Dict[str, Any]):
        """Add the annotation dict."""
        self.ann_row_dicts.append(annotation)

    def add_dicts(self, annotations: List[Dict[str, Any]]):
        """Add the annotation dicts."""
        self.ann_row_dicts.extend(annotations)

    def add_df(self, an_df: pd.DataFrame):
        """Add (concatenate) the annotation dataframe to the current annotations.

        The combined dataframe is re-sorted per the metadata. A None an_df is
        ignored (otherwise pd.concat raises when both sides are None).

        Args:
            an_df: The annotations dataframe to add.
        """
        if an_df is None:
            return
        # pd.concat silently drops a None current df (i.e., no annotations).
        df = self.metadata.sort_df(pd.concat([self.df, an_df]))
        self.set_df(df)

    def _build_list(self) -> List[Dict[str, Any]]:
        """Build the annotations list from the dataframe (releasing the df)."""
        alist = None
        if self._df is not None:
            alist = self._df.to_dict(orient="records")
            self._df = None
        return alist if alist is not None else []

    def _build_df(self) -> pd.DataFrame:
        """Build the sorted annotations df from the list (releasing the list).

        Returns None when there are no annotation dicts.
        """
        df = None
        if self._annotations_list is not None:
            if len(self._annotations_list) > 0:
                df = self.metadata.sort_df(pd.DataFrame(self._annotations_list))
            self._annotations_list = None
        return df

    def set_df(self, df: pd.DataFrame):
        """Set (or reset) this annotation's dataframe.

        Args:
            df: The new annotations dataframe.
        """
        self._df = df
        self._annotations_list = None
class AnnotationsBuilder:
    """A class for building annotation row dictionaries."""

    def __init__(
        self,
        metadata: "AnnotationsMetaData",
        data_defaults: Dict[str, Any],
    ):
        """Initialize AnnotationsBuilder.

        Args:
            metadata: The annotations metadata; if None, a default
                AnnotationsMetaData is used.
            data_defaults: Dict[ann_colname, default_value] with default
                values for annotation columns (may be None).
        """
        self.metadata = metadata if metadata is not None else AnnotationsMetaData()
        self.data_defaults = data_defaults

    def build_annotation_row(
        self, start_pos: int, end_pos: int, text: str, ann_type: str, **kwargs: Any
    ) -> Dict[str, Any]:
        """Build an annotation row with the mandatory key values and those from
        the remaining keyword arguments.

        Data defaults fill in any other columns; kwargs override both the
        defaults and the key values.

        Args:
            start_pos: The token start position.
            end_pos: The token end position.
            text: The token text.
            ann_type: The annotation type.
            **kwargs: Additional keyword arguments for extra annotation fields.

        Returns:
            The result row dictionary.
        """
        return self.do_build_row(
            {
                self.metadata.start_pos_col: start_pos,
                self.metadata.end_pos_col: end_pos,
                self.metadata.text_col: text,
                self.metadata.ann_type_col: ann_type,
            },
            **kwargs,
        )

    def do_build_row(self, key_fields: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
        """Do the row building with the key fields, followed by data defaults,
        followed by any extra kwargs.

        Precedence (lowest to highest): data defaults, key fields, kwargs.
        Column order: key fields, then remaining defaults, then new kwargs.

        Args:
            key_fields: The dictionary of key fields.
            **kwargs: Any extra fields to add.

        Returns:
            The constructed row dictionary.
        """
        result = dict(key_fields)
        if self.data_defaults is not None:
            # Defaults only fill columns the key fields left unset, so they
            # can never clobber the explicitly passed positions/text/type.
            result.update(
                (col, value)
                for col, value in self.data_defaults.items()
                if col not in key_fields
            )
        # Extra kwargs override everything else.
        result.update(kwargs)
        return result
class RowData:
    """Thin wrapper around one annotation row (a pd.Series), e.g. for grouping."""

    def __init__(
        self,
        metadata: AnnotationsMetaData,
        row: pd.Series,
    ):
        self.row = row
        self.metadata = metadata

    def __repr__(self) -> str:
        return f'[{self.start_pos}:{self.end_pos})"{self.text}"'

    @property
    def loc(self):
        """The wrapped row's index label (its Series name)."""
        return self.row.name

    @property
    def start_pos(self) -> int:
        """The row's value in the metadata's start position column."""
        return self.row[self.metadata.start_pos_col]

    @property
    def end_pos(self) -> int:
        """The row's value in the metadata's end position column."""
        return self.row[self.metadata.end_pos_col]

    @property
    def text(self) -> str:
        """The row's value in the metadata's text column."""
        return self.row[self.metadata.text_col]

    def is_subset(self, other_row: "RowData") -> bool:
        """Check whether this row's span lies within the other row's span.

        Args:
            other_row: The candidate enclosing row.

        Returns:
            True when other.start <= self.start and self.end <= other.end.
        """
        starts_inside = other_row.start_pos <= self.start_pos
        return starts_inside and self.end_pos <= other_row.end_pos

    def is_subset_of_any(self, other_rows: List["RowData"]) -> bool:
        """Check whether this row's span lies within any of the given rows.

        Args:
            other_rows: Candidate enclosing rows (scanned in order, stopping
                at the first match).

        Returns:
            True if some row in other_rows encloses this row's span.
        """
        return any(self.is_subset(candidate) for candidate in other_rows)
class AnnotationsGroup:
    """Container for annotation rows that belong together as a (consistent) group.

    NOTE: An instance will only accept rows on condition of consistency per its
    acceptance function.
    """

    def __init__(
        self,
        row_accessor: "AnnotationsRowAccessor",
        field_col_type: str,
        accept_fn: Callable[["AnnotationsGroup", "RowData"], bool],
        group_type: str = None,
        group_num: int = None,
        valid: bool = True,
        autolock: bool = False,
    ):
        """Initialize AnnotationsGroup.

        Args:
            row_accessor: The annotations row_accessor.
            field_col_type: The col_type for the group field_type for retrieval
                using the annotations row accessor.
            accept_fn: A fn(g, row_data) that returns True to accept the row
                data into this group g, or False to reject the row. If None, then
                all rows are always accepted.
            group_type: An optional (override) type for identifying this group.
            group_num: An optional number for identifying this group.
            valid: True if the group is valid, or False if not.
            autolock: True to automatically lock this group when (1) at
                least one row has been added and (2) a row is rejected.
        """
        self.rows = []  # List[RowData]
        self.row_accessor = row_accessor
        self.field_col_type = field_col_type
        self.accept_fn = accept_fn
        self._group_type = group_type
        self._group_num = group_num
        self._valid = valid
        self._autolock = autolock
        self._locked = False
        self._locs = None  # track loc's for recognizing dupes
        self._key = None  # a hash key using the _locs
        self._df = None
        self._ann_type = None

    @property
    def is_locked(self) -> bool:
        """Get whether this group is locked from adding more rows."""
        return self._locked

    @is_locked.setter
    def is_locked(self, value: bool):
        """Set this group as locked (value=True) or unlocked (value=False) to
        disallow or allow more rows from being added.

        Note that while unlocked only rows that pass the accept function will
        be added.

        Args:
            value: True to lock or False to unlock this group.
        """
        self._locked = value

    @property
    def is_valid(self) -> bool:
        """Get whether this group is currently marked as valid."""
        return self._valid

    @is_valid.setter
    def is_valid(self, value: bool):
        """Mark this group as valid (value=True) or invalid (value=False).

        Args:
            value: True for valid or False for invalid.
        """
        self._valid = value

    @property
    def autolock(self) -> bool:
        """Get whether this group is currently set to autolock."""
        return self._autolock

    @autolock.setter
    def autolock(self, value: bool):
        """Set this group to autolock (True) or not (False).

        Args:
            value: True to autolock or False to not autolock.
        """
        self._autolock = value

    def __repr__(self):
        return json.dumps(self.to_dict())

    @property
    def size(self) -> int:
        """Get the number of rows in this group."""
        return len(self.rows)

    @property
    def group_type(self) -> str:
        """Get this group's type, which is either an "override" value that has
        been set, or the "ann_type" value of the first row added.
        """
        return self._group_type if self._group_type is not None else self.ann_type

    @group_type.setter
    def group_type(self, value: str):
        """Set this group's type"""
        self._group_type = value

    @property
    def group_num(self) -> int:
        """Get this group's number"""
        return self._group_num

    @group_num.setter
    def group_num(self, value: int):
        """Set this group's num"""
        self._group_num = value

    @property
    def df(self) -> pd.DataFrame:
        """Get this group as a dataframe (cached until rows change)"""
        if self._df is None:
            self._df = pd.DataFrame([r.row for r in self.rows])
        return self._df

    @property
    def ann_type(self) -> str:
        """Get this group's annotation type (from the first row added)"""
        return self._ann_type

    @property
    def text(self) -> str:
        """Get the rows' text joined with single spaces."""
        return " ".join([row.text for row in self.rows])

    @property
    def locs(self) -> List[int]:
        """Get the index labels (RowData.loc) of this group's rows."""
        if self._locs is None:
            self._locs = [r.loc for r in self.rows]
        return self._locs

    @property
    def key(self) -> str:
        """A hash key for this group: its sorted row locs joined by "_"."""
        if self._key is None:
            self._key = "_".join([str(x) for x in sorted(self.locs)])
        return self._key

    def copy(self) -> "AnnotationsGroup":
        """Get a shallow copy of this group (the rows list is copied, the
        RowData instances are shared).
        """
        result = AnnotationsGroup(
            self.row_accessor,
            self.field_col_type,
            self.accept_fn,
            group_type=self.group_type,
            group_num=self.group_num,
            valid=self.is_valid,
            autolock=self.autolock,
        )
        result.rows = self.rows.copy()
        result._locked = self._locked  # pylint: disable=protected-access
        result._ann_type = self._ann_type  # pylint: disable=protected-access
        return result

    def add(self, rowdata: "RowData") -> bool:
        """Add the row if the group is not locked and the row belongs in this
        group, or return False.

        If autolock is True and a row fails to be added (after the first
        row has been added,) "lock" the group and refuse to accept any more
        rows.

        Args:
            rowdata: The row to add.

        Returns:
            True if the row belongs and was added; otherwise, False.
        """
        result = False
        if self._locked:
            return result

        if self.accept_fn is None or self.accept_fn(self, rowdata):
            self.rows.append(rowdata)
            # Invalidate caches derived from the rows.
            self._df = None
            self._locs = None
            self._key = None
            if self._ann_type is None:
                self._ann_type = self.row_accessor.get_col_value(
                    KEY_ANN_TYPE_COL,
                    rowdata.row,
                    missing=None,
                )
            result = True

        if not result and self.size > 0 and self.autolock:
            self._locked = True

        return result

    def to_dict(self) -> Dict[str, str]:
        """Get this group (record) as a dictionary of field type to text values.

        Each row's field type is read from the row's field_col_type column via
        the row accessor; later rows overwrite earlier rows with the same key.
        """
        return {
            self.row_accessor.get_col_value(self.field_col_type, row.row): row.text
            for row in self.rows
        }

    def is_subset(self, other: "AnnotationsGroup") -> bool:
        """Determine whether this group's text is contained within the other's.

        Args:
            other: The other group.

        Returns:
            True if every row of this group is a span-subset of some row of
            the other group.
        """
        return all(my_row.is_subset_of_any(other.rows) for my_row in self.rows)

    def is_subset_of_any(self, groups: List["AnnotationsGroup"]) -> "AnnotationsGroup":
        """Determine whether this group is a subset of any of the given groups.

        Args:
            groups: List of annotation groups.

        Returns:
            The first AnnotationsGroup that this group is a subset of, or None.
        """
        return next((other for other in groups if self.is_subset(other)), None)

    def remove_row(
        self,
        row_idx: int,
    ) -> "RowData":
        """Remove the row at the given position from this group.

        Args:
            row_idx: The positional index of the row to remove.

        Returns:
            The removed row data instance.
        """
        rowdata = self.rows.pop(row_idx)

        # Reset cached values
        self._df = None
        self._locs = None
        self._key = None

        return rowdata
class MergeStrategy(ABC):
    """Pluggable merge behavior, chosen by the entity types being merged."""

    @abstractmethod
    def merge(self, group: AnnotationsGroup) -> List[Dict[str, Any]]:
        """Merge the rows of ``group`` into annotation dictionaries.

        Args:
            group: The annotations group whose rows are to be merged.

        Returns:
            The merged annotation row dictionaries for the group.
        """
        raise NotImplementedError()
class PositionalAnnotationsGroup(AnnotationsGroup):
    """Container for annotations that either overlap with each other or don't."""

    def __init__(self, overlap: bool, rectype: str = None, gnum: int = -1):
        """Initialize PositionalAnnotationsGroup.

        Args:
            overlap: If False, then only accept rows that don't overlap; else
                only accept rows that do overlap.
            rectype: The record type.
            gnum: The group number.
        """
        super().__init__(None, None, None, group_type=rectype, group_num=gnum)
        self.overlap = overlap
        # Combined [start_pos, end_pos) span of accepted rows; -1 while empty.
        self.start_pos = -1
        self.end_pos = -1

    def __repr__(self) -> str:
        return f'nrows={len(self.rows)}[{self.start_pos},{self.end_pos})"{self.entity_text}"'

    @property
    def entity_text(self) -> str:
        """Get the member rows' text, joined by " | " for overlapping
        alternatives or by a space otherwise.
        """
        jstr = " | " if self.overlap else " "
        return jstr.join(r.text for r in self.rows)

    def belongs(self, rowdata: RowData) -> bool:
        """Determine if the row belongs in this instance based on its overlap
        or not and, if it does, absorb it: record the row and extend this
        group's span to cover it.

        Args:
            rowdata: The rowdata to test.

        Returns:
            True if the rowdata belongs in (and was added to) this instance.
        """
        result = True  # Anything belongs to an empty group
        if len(self.rows) > 0:
            result = self._overlaps(rowdata)
            if not self.overlap:
                result = not result
        if result:
            if self.start_pos < 0:
                self.start_pos = rowdata.start_pos
                self.end_pos = rowdata.end_pos
            else:
                self.start_pos = min(self.start_pos, rowdata.start_pos)
                self.end_pos = max(self.end_pos, rowdata.end_pos)
            # Record the row so later belongs() calls compare against a
            # non-empty group; invalidate caches derived from the rows.
            self.rows.append(rowdata)
            self._df = None
            self._locs = None
            self._key = None
        return result

    def add(self, rowdata: RowData) -> bool:
        """Add the row if this group is unlocked and the row belongs.

        Overrides AnnotationsGroup.add, which needs a row accessor that this
        positional group does not have. Honors locking and autolock.

        Args:
            rowdata: The row to add.

        Returns:
            True if the row belongs and was added; otherwise, False.
        """
        if self.is_locked:
            return False
        result = self.belongs(rowdata)
        if not result and self.size > 0 and self.autolock:
            self.is_locked = True
        return result

    def _overlaps(self, rowdata: RowData) -> bool:
        """Determine whether the row's span overlaps this group's span."""
        start_overlaps = self._is_in_bounds(rowdata.start_pos)
        end_overlaps = self._is_in_bounds(rowdata.end_pos - 1)
        # A row that fully covers the current span overlaps it even though
        # neither of its endpoints falls inside the span.
        covers = rowdata.start_pos <= self.start_pos and rowdata.end_pos >= self.end_pos
        return start_overlaps or end_overlaps or covers

    def _is_in_bounds(self, char_pos):
        """Determine whether char_pos lies within [start_pos, end_pos)."""
        return char_pos >= self.start_pos and char_pos < self.end_pos

    def copy(self) -> "PositionalAnnotationsGroup":
        """Get a shallow copy of this group, preserving type, number, span and
        state flags (the rows list is copied, RowData instances are shared).
        """
        result = PositionalAnnotationsGroup(
            self.overlap, rectype=self._group_type, gnum=self._group_num
        )
        result.start_pos = self.start_pos
        result.end_pos = self.end_pos
        result.rows = self.rows.copy()
        result.is_valid = self.is_valid
        result.autolock = self.autolock
        result.is_locked = self.is_locked
        result._ann_type = self._ann_type  # pylint: disable=protected-access
        return result

    # TODO: Add comparison and merge functions
class OverlapGroupIterator:
    """Given:
      * annotation rows (dataframe)
      * in order sorted by
        * start_pos (increasing for input order), and
        * end_pos (decreasing for longest spans first)
    Collect:
      * overlapping consecutive annotations
      * for processing
    """

    def __init__(self, an_df: pd.DataFrame, metadata: AnnotationsMetaData = None):
        """Initialize OverlapGroupIterator.

        Args:
            an_df: An annotations dataframe, sliced and sorted. None is
                treated as having no rows.
            metadata: The annotations metadata used to interpret row columns;
                defaults to an AnnotationsMetaData with the key column names.
        """
        self.an_df = an_df
        # RowData needs metadata to resolve its start/end/text columns;
        # a None metadata would make every position lookup fail.
        self.metadata = metadata if metadata is not None else AnnotationsMetaData()
        self._cur_iter = None
        self._queued_row_data = None
        self.cur_group = None
        self.reset()

    def next_group(self) -> AnnotationsGroup:
        """Collect the next run of consecutive overlapping rows.

        Returns:
            The next group, or None when the rows are exhausted.
        """
        group = None
        if self.has_next:
            group = PositionalAnnotationsGroup(True)
            while self.has_next and group.belongs(self._queued_row_data):
                self._queue_next()
            self.cur_group = group
        return group

    def reset(self):
        """Restart iteration from the first row of the dataframe."""
        self._cur_iter = self.an_df.iterrows() if self.an_df is not None else iter(())
        self._queue_next()
        self.cur_group = None

    @property
    def has_next(self) -> bool:
        """Get whether another row is queued for grouping."""
        return self._queued_row_data is not None

    def _queue_next(self):
        """Queue the next row as RowData, or None when exhausted."""
        try:
            _loc, row = next(self._cur_iter)
            self._queued_row_data = RowData(self.metadata, row)
        except StopIteration:
            self._queued_row_data = None
def merge(
    annotations: Annotations,
    merge_strategy: MergeStrategy,
) -> Annotations:
    """Merge the overlapping groups according to the given strategy.

    Args:
        annotations: The annotations whose overlapping rows are to be merged.
        merge_strategy: The strategy producing merged dicts for each group.

    Returns:
        New Annotations (sharing the input's metadata) holding the merged
        annotation dicts; empty when the input has no annotations.
    """
    result = Annotations(annotations.metadata)
    an_df = annotations.df
    if an_df is None:
        # No annotations: Annotations.df is None, so there is nothing to merge.
        return result
    og_iter = OverlapGroupIterator(an_df)
    while og_iter.has_next:
        og = og_iter.next_group()
        result.add_dicts(merge_strategy.merge(og))
    return result
class AnnotationsGroupList:
    """Ordered collection of annotation groups, gated by an acceptance function."""

    def __init__(
        self,
        groups: List[AnnotationsGroup] = None,
        accept_fn: Callable[["AnnotationsGroupList", AnnotationsGroup], bool] = (
            lambda lst, g: lst.size == 0 or not g.is_subset_of_any(lst.groups)
        ),
    ):
        """Create the list.

        Args:
            groups: Initial groups (a fresh empty list when None).
            accept_fn: Predicate fn(lst, g) deciding whether group g may join
                list lst; None accepts every group. The default rejects any
                group that is a subset of a group already in the list.
        """
        self.accept_fn = accept_fn
        self.groups = [] if groups is None else groups
        self._coverage = None  # cached count of distinct covered row locs

    def __repr__(self) -> str:
        return str(self.groups)

    @property
    def size(self) -> int:
        """Number of groups held."""
        return len(self.groups)

    @property
    def coverage(self) -> int:
        """Number of distinct (token) row locs covered across all groups."""
        if self._coverage is None:
            self._coverage = len({loc for group in self.groups for loc in group.locs})
        return self._coverage

    @property
    def df(self) -> pd.DataFrame:
        """All groups' rows, in group order, as a single dataframe."""
        all_rows = [rowdata.row for group in self.groups for rowdata in group.rows]
        return pd.DataFrame(all_rows)

    def copy(self) -> "AnnotationsGroupList":
        """Shallow copy: a new groups list holding the same group objects."""
        duplicate = AnnotationsGroupList(self.groups.copy(), accept_fn=self.accept_fn)
        duplicate._coverage = self._coverage  # pylint: disable=protected-access
        return duplicate

    def add(self, group: AnnotationsGroup) -> bool:
        """Append the group when the acceptance function allows it.

        Args:
            group: Candidate group.

        Returns:
            True when appended, False when rejected.
        """
        if self.accept_fn is not None and not self.accept_fn(self, group):
            return False
        self.groups.append(group)
        self._coverage = None
        return True

    def is_subset(self, other: "AnnotationsGroupList") -> bool:
        """Check whether each of this list's groups is a subset of some group
        in the other list.

        Args:
            other: The other group list.

        Returns:
            True if every group here is contained within some group of other.
        """
        return all(mine.is_subset_of_any(other.groups) for mine in self.groups)
class AnnotatedText(dk_doc.Text):
    """A Text object that manages its own annotations."""

    def __init__(
        self,
        text_str: str,
        metadata: dk_doc.TextMetaData = None,
        annots: Annotations = None,
        bookmarks: Dict[str, pd.DataFrame] = None,
        text_obj: dk_doc.Text = None,
        annots_metadata: AnnotationsMetaData = None,
    ):
        """Initialize AnnotatedText.

        Args:
            text_str: The text string.
            metadata: The text's metadata.
            annots: The annotations.
            bookmarks: The annotation bookmarks.
            text_obj: A text_obj to override text_str and metadata initialization.
            annots_metadata: Override for default annotations metadata
                (NOTE: ineffectual if an annots instance is provided.)
        """
        super().__init__(
            text_obj.text if text_obj is not None else text_str,
            text_obj.metadata if text_obj is not None else metadata,
        )
        self._annots = annots
        self._bookmarks = bookmarks
        self._annots_metadata = annots_metadata

    @property
    def annotations(self) -> Annotations:
        """Get this object's annotations (created empty on first access)"""
        if self._annots is None:
            self._annots = Annotations(self._annots_metadata or AnnotationsMetaData())
        return self._annots

    @property
    def bookmarks(self) -> Dict[str, pd.DataFrame]:
        """Get this object's bookmarks (created empty on first access)"""
        if self._bookmarks is None:
            self._bookmarks = {}
        return self._bookmarks

    def get_text(
        self,
        annot2mask: Dict[str, str] = None,
        annot_df: pd.DataFrame = None,
        text: str = None,
    ) -> str:
        """Get the text object's string, masking if indicated.

        Args:
            annot2mask: Mapping from annotation column (e.g., _num or
                _recsnum) to the replacement character(s) in the input text
                for masking already managed input.
            annot_df: Override annotations dataframe (defaults to this
                object's annotations).
            text: Override text (defaults to this object's text).

        Returns:
            The (masked) text.
        """
        if annot2mask is None:
            return self.text if text is None else text
        text_s = self.get_text_series(text=text)  # no padding
        if annot_df is None:
            annot_df = self.annotations.df
        text_s = self._apply_mask(text_s, annot2mask, annot_df)
        return "".join(text_s)

    def get_text_series(
        self,
        pad_len: int = 0,
        text: str = None,
    ) -> pd.Series:
        """Get the input text as a (padded) pandas series.

        Args:
            pad_len: The number of spaces to pad both front and back.
            text: Override text.

        Returns:
            The (padded) pandas series of input characters.
        """
        if text is None:
            text = self.text
        return pd.Series(list(" " * pad_len + text + " " * pad_len))

    def get_annot_mask(
        self,
        annot_col: str,
        pad_len: int = 0,
        annot_df: pd.DataFrame = None,
        text: str = None,
    ) -> pd.Series:
        """Get a True/False series for the input such that start to end positions
        for rows where the annotation column is non-null and non-empty are
        True.

        Args:
            annot_col: The annotation column identifying chars to mask.
            pad_len: The number of characters to pad the mask with False
                values at both the front and back.
            annot_df: Override annotations dataframe.
            text: Override text.

        Returns:
            A pandas Series where annotated input character positions
            are True and non-annotated positions are False.
        """
        if annot_df is None:
            annot_df = self.annotations.df
        if text is None:
            text = self.text
        textlen = len(text)
        return self._get_annot_mask(annot_df, textlen, annot_col, pad_len=pad_len)

    @staticmethod
    def _get_annot_mask(
        annot_df: pd.DataFrame,
        textlen: int,
        annot_col: str,
        pad_len: int = 0,
    ) -> pd.Series:
        """Get a True/False series for the input such that start to end positions
        for rows where the annotation column is non-null and non-empty are
        True.

        Args:
            annot_df: The annotations dataframe (None means no annotations).
            textlen: The length of the input text.
            annot_col: The annotation column identifying chars to mask.
            pad_len: The number of characters to pad the mask with False
                values at both the front and back.

        Returns:
            A pandas Series of length textlen + 2 * pad_len where annotated
            input character positions are True and all others are False
            (all False when annot_df is None or lacks annot_col).
        """
        # Size for the padding on both sides; otherwise label-based slice
        # assignment silently drops positions shifted past textlen.
        mask = pd.Series([False] * (textlen + 2 * pad_len))
        if annot_df is None or annot_col not in annot_df.columns:
            return mask
        df = annot_df[np.logical_and(annot_df[annot_col].notna(), annot_df[annot_col] != "")]
        for _, row in df.iterrows():
            # end_pos is exclusive, while .loc label slices are inclusive.
            mask.loc[row["start_pos"] + pad_len : row["end_pos"] - 1 + pad_len] = True
        return mask

    def _apply_mask(
        self,
        text_s: pd.Series,
        annot2mask: Dict[str, str],
        annot_df: pd.DataFrame,
    ) -> pd.Series:
        """Mask text_s at positions annotated by any column in annot2mask.

        Columns are applied in annot2mask order so results are deterministic;
        where masks overlap, the last applicable column's replacement wins.

        Args:
            text_s: The text character series.
            annot2mask: Mapping from annotation column to replacement text.
            annot_df: The annotations dataframe.

        Returns:
            The masked character series.
        """
        if len(text_s) == 0 or annot2mask is None or annot_df is None:
            return text_s
        for col, repl_mask in annot2mask.items():
            if col in annot_df.columns:
                text_s = self._substitute(text_s, col, repl_mask, annot_df)
        return text_s

    def _substitute(
        self,
        text_s: pd.Series,
        col: str,
        repl_mask: str,
        annot_df: pd.DataFrame,
    ) -> pd.Series:
        """Substitute the "mask" char for "text" chars at "col"-annotated positions.

        Args:
            text_s: The text series to revise.
            col: The annotation col identifying positions to mask.
            repl_mask: The mask character to inject at annotated positions.
            annot_df: The annotations dataframe.

        Returns:
            The masked text series.
        """
        annot_mask = self._get_annot_mask(annot_df, len(text_s), col)
        return text_s.mask(annot_mask, repl_mask)

    def add_annotations(self, annotations: Annotations):
        """Add the annotations to this instance.

        If this instance has no annotations yet, the given instance is adopted
        as-is; if it has empty annotations, a copy of the given dataframe is
        set; otherwise the dataframes are concatenated (and re-sorted).

        Args:
            annotations: The annotations to add.
        """
        if annotations is not None and not annotations.is_empty():
            df = annotations.df
            if self._annots is None:
                self._annots = annotations
            elif self._annots.is_empty():
                if df is not None:
                    self._annots.set_df(df.copy())
            elif df is not None:
                self._annots.add_df(df)
class Annotator(ABC):
    """Abstract base for components that annotate text."""

    def __init__(self, name: str):
        """Create a named annotator.

        Args:
            name: Identifier for this annotator.
        """
        self.name = name

    @abstractmethod
    def annotate_input(self, text_obj: AnnotatedText, **kwargs: Any) -> Annotations:
        """Annotate ``text_obj``, additively updating its annotations.

        Args:
            text_obj: The annotated text object to process.
            **kwargs: Implementation-specific options.

        Returns:
            The annotations that were added.
        """
        raise NotImplementedError()
class BasicAnnotator(Annotator):
    """Class for extracting basic (possibly multi -level or -part) entities.

    Subclasses implement ``annotate_text``; ``annotate_input`` applies it to
    an AnnotatedText's text and appends the result to that object's
    annotations.
    """

    def annotate_input(
        self,
        text_obj: AnnotatedText,
        **kwargs: Any
    ) -> Annotations:
        """Annotate the text obj, additively updating the annotations.

        Args:
            text_obj: The text to annotate.
            **kwargs: Additional keyword arguments (unused here).

        Returns:
            The annotations added to the text, or None when ``annotate_text``
            produced none (in which case nothing is added).
        """
        annots = self.annotate_text(text_obj.text)
        if annots is None:
            # Same convention as EntityAnnotator.annotate_input: nothing to add.
            return annots

        # Append the new annotations to the text object's own annotations
        # (SyntacticParser, by contrast, stores its result only as a bookmark).
        # NOTE(review): EntityAnnotator reads ``annots.df`` while this reads
        # ``annots.as_df``; confirm both accessors are intended.
        text_obj.annotations.add_df(annots.as_df)

        return annots

    @abstractmethod
    def annotate_text(self, text_str: str) -> Annotations:
        """Build annotations for the text string.

        Args:
            text_str: The text string to annotate.

        Returns:
            Annotations for the text.
        """
        raise NotImplementedError
1174# TODO: remove this if unused -- stanza_annotator isa Authority -vs- stanza_annotator isa SyntacticParser
class SyntacticParser(BasicAnnotator):
    """Class for creating syntactic annotations for an input.

    Unlike ``BasicAnnotator.annotate_input``, the annotations produced here are
    not added to ``text_obj.annotations``; they are stored in
    ``text_obj.bookmarks`` under this parser's name, where, for example, an
    EntityAnnotator's ``merge_strategies`` can look them up by tag.
    """

    def annotate_input(
        self,
        text_obj: AnnotatedText,
        **kwargs: Any
    ) -> Annotations:
        """Annotate the text, storing the result as a bookmark.

        Args:
            text_obj: The text to annotate.
            **kwargs: Additional keyword arguments (unused here).

        Returns:
            The syntactic annotations, or None when ``annotate_text`` produced
            none (in which case no bookmark is written).
        """
        # Get new annotation with just the syntax
        annots = self.annotate_text(text_obj.text)
        if annots is None:
            # Same convention as EntityAnnotator.annotate_input.
            return annots

        # Add syntactic annotations only as a bookmark
        text_obj.bookmarks[self.name] = annots.as_df

        return annots
class EntityAnnotator(BasicAnnotator):
    """Class for extracting single (possibly multi-level or -part) entities.

    Subclasses provide ``annotate_text`` plus the hooks that
    ``annotate_input`` runs, in order, on the raw annotations:
    ``compose_groups``, ``mark_records`` and ``validate_records``.
    """

    def __init__(
        self,
        name: str,
        mask_char: str = " ",
    ):
        """Initialize EntityAnnotator.

        Args:
            name: The name of this annotator.
            mask_char: The character intended for masking out previously
                annotated spans of this annotator's text. NOTE: masking is not
                yet implemented in ``annotate_input`` (see its TODO), so this
                value is currently stored but not used there.
        """
        super().__init__(name)
        self.mask_char = mask_char

    @property
    @abstractmethod
    def annotation_cols(self) -> Set[str]:
        """Report the (final group or record) annotation columns that are filled
        by this annotator when its entities are annotated.
        """
        raise NotImplementedError

    @abstractmethod
    def mark_records(self, annotations: Annotations, largest_only: bool = True):
        """Collect and mark annotation records.

        Args:
            annotations: The annotations.
            largest_only: True to only mark (keep) the largest records.
        """
        raise NotImplementedError

    @abstractmethod
    def validate_records(
        self,
        annotations: Annotations,
    ):
        """Validate annotated records.

        Args:
            annotations: The annotations.
        """
        raise NotImplementedError

    @abstractmethod
    def compose_groups(self, annotations: Annotations) -> Annotations:
        """Compose annotation rows into groups.

        Args:
            annotations: The annotations.

        Returns:
            The composed annotations.
        """
        raise NotImplementedError

    def annotate_input(
        self,
        text_obj: AnnotatedText,
        annot_mask_cols: Set[str] = None,
        merge_strategies: Dict[str, MergeStrategy] = None,
        largest_only: bool = True,
        **kwargs: Any
    ) -> Annotations:
        """Annotate the text object, additively updating its annotations.

        Steps: annotate the raw text; optionally merge in bookmarked
        annotations; compose groups; mark and validate records; and finally
        append the result to ``text_obj.annotations``.

        Args:
            text_obj: The text object to annotate.
            annot_mask_cols: The (possible) previous annotations whose spans
                should be ignored in the text. NOTE: currently ignored --
                masking is not implemented yet (see the TODO below).
            merge_strategies: A dictionary of each input annotation bookmark
                tag mapped to a merge strategy for merging this annotator's
                annotations with the bookmarked dataframe. This is useful, for
                example, when merging syntactic information to refine
                ambiguities. For each tag found in ``text_obj.bookmarks``,
                this annotator's pre-merge dataframe is first saved under the
                bookmark ``"<name>.pre-merge:<tag>"``.
            largest_only: True to only mark largest records.
            **kwargs: Additional keyword arguments (unused here).

        Returns:
            The annotations added to the text object, or None when
            ``annotate_text`` produced none (nothing is added in that case).
        """
        # TODO: honor annot_mask_cols by masking those columns' annotated
        # spans with self.mask_char (i.e. a {col: self.mask_char} mapping)
        # before annotating the text.
        annots = self.annotate_text(text_obj.text)
        if annots is None:
            return annots

        if merge_strategies is not None:
            bookmarks = text_obj.bookmarks
            if bookmarks is not None and len(bookmarks) > 0:
                for tag, merge_strategy in merge_strategies.items():
                    if tag in bookmarks:
                        # Keep this annotator's un-merged dataframe under a
                        # derived bookmark before merging in the tagged one.
                        text_obj.bookmarks[f"{self.name}.pre-merge:{tag}"] = annots.df
                        annots.add_df(bookmarks[tag])
                        annots = merge(annots, merge_strategy)

        annots = self.compose_groups(annots)

        self.mark_records(annots, largest_only=largest_only)
        # NOTE: don't pass "text" here because it may be masked
        self.validate_records(annots)
        text_obj.annotations.add_df(annots.df)
        return annots

    @property
    @abstractmethod
    def highlight_fieldstyles(self) -> Dict[str, Dict[str, Dict[str, str]]]:
        """Get highlight field styles for this annotator's annotations of the form:
            {
                <field_col>: {
                    <field_value>: {
                        <css-attr>: <css-value>
                    }
                }
            }
        where css-attr is a CSS property such as 'background-color' or 'color'.
        """
        raise NotImplementedError
class HtmlHighlighter:
    """Helper class to add HTML markup for highlighting spans of text."""

    def __init__(
        self,
        field2style: Dict[str, Dict[str, Dict[str, str]]],
        tooltip_class: str = "tooltip",
        tooltiptext_class: str = "tooltiptext",
    ):
        """Initialize HtmlHighlighter.

        Args:
            field2style: Each annotation column to highlight, mapped from its
                annotation values to the CSS style for that value, for example:
                {
                    'car_model_field': {
                        'year': {'background-color': 'lightyellow'},
                        'make': {'background-color': 'lightgreen'},
                        'model': {'background-color': 'cyan'},
                        'style': {'background-color': 'magenta'},
                    },
                }
            tooltip_class: The css tooltip class.
            tooltiptext_class: The css tooltiptext class.
        """
        self.field2style = field2style
        self.tooltip_class = tooltip_class
        self.tooltiptext_class = tooltiptext_class

    def highlight(
        self,
        text_obj: AnnotatedText,
    ) -> str:
        """Return an html string with the given fields (annotation columns)
        highlighted with the associated styles.

        All configured fields are rendered in one left-to-right pass over the
        text:
          * spans are ordered by start position (longer spans first on ties);
          * a span overlapping an already-highlighted span is skipped, so no
            text is emitted twice;
          * un-annotated text, including text after the last span, is kept;
          * all text and labels are HTML-escaped.
        Fields missing from the annotations dataframe are ignored.

        Args:
            text_obj: The annotated text to markup.

        Returns:
            HTML string with highlighted annotations.
        """
        import html  # only needed here, for escaping text content

        text = text_obj.text
        anns = text_obj.annotations
        an_df = anns.df
        start_col = anns.metadata.start_pos_col
        end_col = anns.metadata.end_pos_col

        # Collect (start, end, label, css) for every styled annotation row.
        spans = []
        if an_df is not None:
            for field, styles in self.field2style.items():
                if field not in an_df.columns:
                    continue
                df = an_df[an_df[field].isin(list(styles))]
                for _loc, row in df.iterrows():
                    enttype = row[field]
                    style_str = " ".join(
                        f"{key}: {value};" for key, value in styles[enttype].items()
                    )
                    spans.append(
                        (int(row[start_col]), int(row[end_col]), enttype, style_str)
                    )
        # Start ascending, end descending: enclosing spans come first.
        spans.sort(key=lambda span: (span[0], -span[1]))

        result = ["<p>"]
        cur_pos = 0
        for start_pos, end_pos, enttype, style_str in spans:
            if start_pos < cur_pos:
                # Overlaps text already emitted; highlighting it again would
                # duplicate that text in the output.
                continue
            if start_pos > cur_pos:
                result.append(html.escape(text[cur_pos:start_pos]))
            label = html.escape(str(enttype))
            result.append(f'<mark class="{self.tooltip_class}" style="{style_str}">')
            result.append(html.escape(text[start_pos:end_pos]))
            result.append(f'<span class="{self.tooltiptext_class}">{label}</span>')
            result.append("</mark>")
            cur_pos = end_pos
        if cur_pos < len(text):
            # Keep any text after the last highlighted span.
            result.append(html.escape(text[cur_pos:]))
        result.append("</p>")
        return "\n".join(result)
class AnnotatorKernel(ABC):
    """Abstract holder of the core annotation logic for a set of annotators.

    Concrete kernels expose their entity annotators and implement how all
    annotations are executed over a text object.
    """

    @property
    @abstractmethod
    def annotators(self) -> List[EntityAnnotator]:
        """The entity annotators managed by this kernel."""
        raise NotImplementedError

    @abstractmethod
    def annotate_input(self, text_obj: AnnotatedText) -> Annotations:
        """Execute all of this kernel's annotations on ``text_obj``.

        Args:
            text_obj: The text object to annotate.

        Returns:
            The resulting annotations.
        """
        raise NotImplementedError
class CompoundAnnotator(Annotator):
    """Class to apply a series of annotators through an AnnotatorKernel"""

    def __init__(
        self,
        kernel: AnnotatorKernel,
        name: str = "entity",
    ):
        """Initialize with the annotators and this extractor's name.

        Args:
            kernel: The annotations kernel to use.
            name: The name of this information extractor to be the
                annotations base column name for <name>_num and <name>_recsnum.
        """
        super().__init__(name=name)
        self.kernel = kernel

    def annotate_input(
        self,
        text_obj: AnnotatedText,
        reset: bool = True,
        **kwargs: Any
    ) -> Annotations:
        """Annotate the text by delegating to the kernel.

        Args:
            text_obj: The AnnotatedText object to annotate.
            reset: When True, clear any existing annotations on ``text_obj``
                before the kernel runs.
            **kwargs: Additional keyword arguments (not forwarded to the kernel).

        Returns:
            The annotations returned by the kernel.
        """
        if reset:
            text_obj.annotations.clear()
        return self.kernel.annotate_input(text_obj)

    def get_html_highlighted_text(
        self,
        text_obj: AnnotatedText,
        annotator_names: List[str] = None,
    ) -> str:
        """Get html-highlighted text for the input's annotations from the
        given annotators (or all of them).

        Args:
            text_obj: The input text to highlight.
            annotator_names: The subset of annotators (by name) to highlight;
                None for all of the kernel's annotators.

        Returns:
            HTML string with highlighted text.
        """
        annotators = self.kernel.annotators
        if annotator_names is None:
            wanted = {ann.name for ann in annotators}
        else:
            wanted = set(annotator_names)
        # Each annotator's highlight_fieldstyles is already
        # {field_col: {field_value: {css-attr: css-value}}} -- the shape
        # HtmlHighlighter expects -- so merge them per field rather than
        # nesting them under annotator names (which would add a level).
        field2style: Dict[str, Dict[str, Dict[str, str]]] = {}
        for ann in annotators:
            if ann.name not in wanted:
                continue
            for field, styles in (ann.highlight_fieldstyles or {}).items():
                field2style.setdefault(field, {}).update(styles)
        return HtmlHighlighter(field2style).highlight(text_obj)