Source code for zink.extractor

try:
    from gliner import GLiNER
except ImportError:
    GLiNER = None
import warnings
import concurrent.futures
# Install an "ignore" filter for every warning category, for the whole
# process (not only this module), as soon as this module is imported.
# simplefilter("ignore") adds exactly the same filter entry as
# filterwarnings("ignore") with no message/module pattern.
# NOTE(review): a blanket filter installed by a library also hides warnings
# raised by the host application — consider scoping it with
# warnings.catch_warnings() around the noisy calls instead.
warnings.simplefilter("ignore")


class EntityExtractor:
    """Zero-shot named-entity extractor backed by a GLiNER model.

    The model is loaded in ``__init__`` only when the optional ``gliner``
    package could be imported; otherwise ``self.model`` is ``None`` and the
    prediction methods raise ``ImportError``.

    Entities are the dicts returned by ``model.predict_entities``; this class
    relies on their ``"start"``, ``"end"`` (character offsets) and ``"score"``
    keys.
    """

    def __init__(self, model_name="deepanwa/NuNerZero_onnx"):
        """Load ``model_name`` (ONNX weights and tokenizer) if gliner is available.

        Parameters:
            model_name (str): Model identifier passed to
                ``GLiNER.from_pretrained``. (Previous model: ``numind/NuNerZero``.)
        """
        if GLiNER is None:
            # gliner is optional; the prediction methods report its absence.
            self.model = None
        else:
            self.model = GLiNER.from_pretrained(
                model_name, load_onnx_model=True, load_tokenizer=True
            )
        self.model_name = model_name
        # NuNerZero requires lower-cased labels.
        self.labels = ["person", "date", "location"]

    # ------------------------------------------------------------------ helpers

    def _require_model(self):
        """Return the loaded model, or raise ImportError if gliner was unavailable."""
        if self.model is None:
            raise ImportError(
                "EntityExtractor requires the optional 'gliner' package; "
                "install it to run entity prediction."
            )
        return self.model

    def _resolve_labels(self, labels):
        """Return ``labels`` lower-cased, or the instance defaults when ``labels`` is None."""
        if labels is None:
            return self.labels
        return [label.lower() for label in labels]

    @staticmethod
    def _chunk_labels(labels, size):
        """Split ``labels`` into consecutive groups of at most ``size`` labels."""
        if size < 1:
            raise ValueError("label_chunk_size must be a positive integer")
        return [labels[i:i + size] for i in range(0, len(labels), size)]

    @staticmethod
    def _mask_spans(text, entities):
        """Return ``text`` with every entity's [start, end) span replaced by spaces.

        Same-length masking keeps all character offsets valid, so spans found
        on the masked text refer to the same positions in the original text.
        """
        chars = list(text)
        for entity in entities:
            for i in range(entity["start"], entity["end"]):
                chars[i] = " "
        return "".join(chars)

    def _two_pass(self, text, chunk):
        """Run the two-pass extraction for one chunk of labels.

        Pass 1 uses threshold 0.5 on ``text``; pass 2 uses threshold 0.9 on a
        copy of ``text`` with the pass-1 entities masked out. The masking is
        local to this call. Returns pass-1 entities followed by pass-2
        entities (duplicate spans are possible and resolved by the caller).
        """
        model = self._require_model()
        first_pass = model.predict_entities(text, chunk, threshold=0.5)
        masked = self._mask_spans(text, first_pass)
        second_pass = model.predict_entities(masked, chunk, threshold=0.9)
        return first_pass + second_pass

    @staticmethod
    def _best_by_span(entities):
        """Keep the highest-scoring entity per (start, end) span, sorted by start.

        When two entities for the same span have equal scores, the first one
        seen is kept.
        """
        best = {}
        for entity in entities:
            span = (entity["start"], entity["end"])
            current = best.get(span)
            if current is None or entity["score"] > current["score"]:
                best[span] = entity
        return sorted(best.values(), key=lambda e: e["start"])

    def _process_chunk_thread_safe(self, chunk, text=None):
        """Thread-pool worker: run the two-pass extraction for one label chunk.

        Parameters:
            chunk (list of str): Labels to predict.
            text (str, optional): Text to analyse. When omitted, falls back to
                ``self.thread_text`` (legacy behaviour). ``predict2`` no longer
                stores text on the instance; it hands the text to each worker.
        """
        if text is None:
            text = self.thread_text
        return self._two_pass(text, chunk)

    # ------------------------------------------------------------ public API

    def predict(self, text, labels=None, max_passes=2):
        """Iteratively find entities by masking found entities and re-running the model.

        The first pass uses the model's default threshold; each later pass runs
        on the text with all previously found entities masked and only keeps
        predictions with score >= 0.9. Iteration stops early when a pass finds
        no entity with a not-yet-seen (start, end) span.

        Parameters:
            text (str): The input text.
            labels (list of str, optional): Entity labels to predict (they are
                lower-cased). Defaults to ``self.labels``.
            max_passes (int): Maximum number of model passes; a safeguard
                against endless re-running.

        Returns:
            list of dict: All entities with unique spans across all passes,
            sorted by start offset. Empty if there are no labels.

        Raises:
            ImportError: If the gliner package (and therefore the model) is
                unavailable.
        """
        predict_labels = self._resolve_labels(labels)
        if not predict_labels:
            return []
        model = self._require_model()

        all_entities = []
        seen_spans = set()
        current_text = text

        for pass_index in range(max_passes):
            if pass_index == 0:
                found = model.predict_entities(current_text, predict_labels)
            else:
                # Later passes only accept confident predictions on the masked text.
                found = model.predict_entities(
                    current_text, predict_labels, threshold=0.9
                )
            if not found:
                break

            # Drop spans already reported so repeated predictions cannot loop.
            new_entities = []
            for entity in found:
                span = (entity["start"], entity["end"])
                if span not in seen_spans:
                    seen_spans.add(span)
                    new_entities.append(entity)
            if not new_entities:
                break

            all_entities.extend(new_entities)
            # Mask with spaces so offsets in the next pass still match `text`.
            current_text = self._mask_spans(current_text, new_entities)

        all_entities.sort(key=lambda e: e["start"])
        return all_entities

    def predict_thorough(self, text, labels=None, label_chunk_size=3):
        """Detailed extraction: two passes per label chunk, then span-level conflict resolution.

        Labels are split into chunks of ``label_chunk_size``. Each chunk is run
        on the original text in two passes (threshold 0.5, then 0.9 on a copy
        with the pass-1 entities masked). Across all chunks, the entity with
        the highest score is kept for each unique (start, end) span.

        Parameters:
            text (str): The input text.
            labels (list of str, optional): Entity labels to predict (they are
                lower-cased). Defaults to ``self.labels``.
            label_chunk_size (int): Number of labels sent to the model per call.

        Returns:
            list of dict: The highest-confidence entity for each unique span,
            sorted by start offset. Empty if there are no labels.

        Raises:
            ImportError: If the gliner package is unavailable.
            ValueError: If ``label_chunk_size`` is less than 1.
        """
        predict_labels = self._resolve_labels(labels)
        if not predict_labels:
            return []

        found = []
        for chunk in self._chunk_labels(predict_labels, label_chunk_size):
            found.extend(self._two_pass(text, chunk))
        return self._best_by_span(found)

    def predict2(self, text, labels=None, label_chunk_size=2, max_workers=4):
        """Same algorithm as ``predict_thorough``, with label chunks processed on a thread pool.

        The text is passed to each worker directly instead of being stored on
        the instance, so concurrent calls on a shared extractor do not
        overwrite each other's input.
        NOTE(review): workers share ``self.model``; this assumes its
        ``predict_entities`` may be called concurrently — confirm.

        Parameters:
            text (str): The input text.
            labels (list of str, optional): Entity labels to predict (they are
                lower-cased). Defaults to ``self.labels``.
            label_chunk_size (int): Number of labels per worker task.
            max_workers (int or None): Passed to ``ThreadPoolExecutor``.

        Returns:
            list of dict: The highest-confidence entity for each unique span,
            sorted by start offset. Empty if there are no labels.

        Raises:
            ImportError: If the gliner package is unavailable.
            ValueError: If ``label_chunk_size`` is less than 1.
        """
        predict_labels = self._resolve_labels(labels)
        if not predict_labels:
            return []
        # Fail fast in the caller's thread rather than inside a worker.
        self._require_model()
        chunks = self._chunk_labels(predict_labels, label_chunk_size)

        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            # map() yields results in chunk order, keeping tie-breaking deterministic.
            per_chunk = list(executor.map(lambda chunk: self._two_pass(text, chunk), chunks))

        return self._best_by_span(entity for batch in per_chunk for entity in batch)
# Shared, module-level extractor built once when this module is imported.
# Building it runs EntityExtractor.__init__ with the default model name, which
# calls GLiNER.from_pretrained when the optional `gliner` package imported
# successfully (otherwise the instance's model is None).
# NOTE(review): from_pretrained presumably fetches/loads model weights, which
# would make importing this module slow and possibly network-dependent —
# confirm, and consider lazy construction if that matters to callers.
_DEFAULT_EXTRACTOR = EntityExtractor()