Coverage for src / kemi / entities.py: 98%
44 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-06-05 15:47 +0000
1"""Entity extraction for entity-aware retrieval.
3Provides pluggable entity linkers that extract normalized entity strings
4from text. The default :class:`RegexEntityLinker` uses regex heuristics for
5names, dates, emails, and URLs with zero external dependencies.
6"""
8from __future__ import annotations
10import re
11from abc import ABC, abstractmethod
12from typing import Any
class EntityLinker(ABC):
    """Pluggable contract for turning text into a set of entities.

    Concrete linkers are expected to hand back **normalized** strings
    (for instance lower-cased and stripped), which keeps overlap
    comparisons between entity sets case-insensitive.
    """

    @abstractmethod
    def extract(self, text: str) -> set[str]:
        """Return the entities found in *text*.

        Args:
            text: Input string.

        Returns:
            Set of normalized entity strings.
        """
class NoopEntityLinker(EntityLinker):
    """Entity linker that never finds anything.

    Serves as the stand-in linker when entity-aware retrieval is disabled.
    """

    def extract(self, text: str) -> set[str]:
        """Ignore *text* and return a fresh, empty set."""
        no_entities: set[str] = set()
        return no_entities
class RegexEntityLinker(EntityLinker):
    """Regex-based entity linker with no external dependencies.

    Extracts:
    - Capitalized phrases (names, places, organisations)
    - Email addresses
    - URLs (``http://``, ``https://`` or ``www.`` prefixed)
    - Dates: ISO-style (YYYY-MM-DD), slash-separated numeric dates
      (e.g. 06/05/2024) and relaxed dates (Month DD, YYYY)

    All entities are normalised to lower-case.
    """

    _DATE_PATTERNS = [
        re.compile(r"\b\d{4}-\d{2}-\d{2}\b"),  # 2024-06-05
        re.compile(r"\b\d{1,2}/\d{1,2}/\d{2,4}\b"),  # 06/05/2024
        re.compile(
            r"\b(?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{1,2},?\s+\d{4}\b",
            re.IGNORECASE,
        ),  # June 5, 2024
    ]

    # The TLD class is ``[A-Za-z]``: inside a character class ``|`` is a
    # literal pipe, not alternation, so ``[A-Z|a-z]`` wrongly let an email
    # run on through a following ``|``.
    _EMAIL_PATTERN = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b")
    _URL_PATTERN = re.compile(r"https?://[^\s]+|www\.[^\s]+")
    _NAME_PATTERN = re.compile(r"\b[A-Z][a-zA-Z]+(?:\s+[A-Z][a-zA-Z]+)*\b")

    # Sentence punctuation that ``[^\s]+`` picks up when a URL ends a clause.
    _URL_TRAILING_PUNCTUATION = ".,;:!?"

    def extract(self, text: str) -> set[str]:
        """Extract names, dates, emails and URLs from *text*.

        Args:
            text: Input string.

        Returns:
            Set of lower-cased entity strings; empty when nothing matches.
        """
        entities: set[str] = set()

        # Names, dates and emails are stored exactly as matched (lower-cased).
        verbatim_patterns = (
            self._NAME_PATTERN,
            *self._DATE_PATTERNS,
            self._EMAIL_PATTERN,
        )
        for pattern in verbatim_patterns:
            entities.update(match.group().lower() for match in pattern.finditer(text))

        # URLs: drop trailing sentence punctuation so that "…example.com."
        # and "…example.com" normalise to the same entity.
        for match in self._URL_PATTERN.finditer(text):
            url = match.group().rstrip(self._URL_TRAILING_PUNCTUATION)
            entities.add(url.lower())

        return entities
class SpacyEntityLinker(EntityLinker):
    """spaCy NER-based entity linker.

    Uses spaCy's named-entity recognition pipeline for accurate extraction
    of people, organisations, locations, dates, products, etc.

    Requires ``spacy`` and a language model (e.g. ``en_core_web_sm``) to be
    installed:

    .. code-block:: bash

        pip install spacy
        python -m spacy download en_core_web_sm

    All extracted entities are normalised to lower-case.

    Args:
        model: spaCy model name (default ``en_core_web_sm``).
        allowed_labels: Set of spaCy NER labels to keep. If ``None``,
            a sensible default set is used.
            See https://spacy.io/usage/linguistic-features#named-entities

    Raises:
        ImportError: If ``spacy`` cannot be imported.
    """

    _DEFAULT_LABELS: set[str] = {
        "PERSON",
        "ORG",
        "GPE",
        "LOC",
        "DATE",
        "EVENT",
        "PRODUCT",
        "WORK_OF_ART",
        "LAW",
        "LANGUAGE",
        "FAC",
        "NORP",
    }

    def __init__(
        self,
        model: str = "en_core_web_sm",
        allowed_labels: set[str] | None = None,
    ) -> None:
        # Imported lazily so the module stays importable without spaCy.
        try:
            import spacy
        except ImportError as exc:
            raise ImportError(
                "spaCy is required for SpacyEntityLinker. "
                "Install with: pip install spacy && python -m spacy download en_core_web_sm"
            ) from exc

        # Errors raised by spacy.load itself are not caught here and
        # propagate to the caller unchanged.
        self._nlp: Any = spacy.load(model)

        # Store a private copy: aliasing the caller's set (or the shared
        # class-level default) would let a later mutation silently change
        # which labels this linker - or every linker - keeps.
        chosen_labels = self._DEFAULT_LABELS if allowed_labels is None else allowed_labels
        self._allowed_labels: set[str] = set(chosen_labels)

    def extract(self, text: str) -> set[str]:
        """Run the spaCy pipeline on *text* and collect the allowed entities.

        Args:
            text: Input string.

        Returns:
            Set of lower-cased entity texts whose NER label is in the
            allowed label set.
        """
        doc = self._nlp(text)
        return {
            ent.text.lower()
            for ent in doc.ents
            if ent.label_ in self._allowed_labels
        }