"""This module contains some helper functions for the models package"""
import enum
import json
import logging
import os
import re
from tempfile import mkstemp
import numpy as np

import nltk
from sklearn.metrics import make_scorer

from ..gazetteer import Gazetteer
from ..text_preparation.text_preparation_pipeline import TextPreparationPipelineFactory

logger = logging.getLogger(__name__)


# Example types
# NOTE(review): QUERY_EXAMPLE_TYPE and ENTITY_EXAMPLE_TYPE are used by
# register_query_feature(), register_entity_feature() and register_feature(),
# but no definition is visible under this header. register_feature()'s error
# message suggests their values are "query" and "entity" -- confirm they are
# defined somewhere before those functions are called.

# Label types
# NOTE(review): no label-type constants are defined under this header in this
# file; confirm whether they were intentionally moved or removed.

# resource/requirements names
# String keys naming classifier resources; requires() attaches such keys to a
# feature extractor as its "requirements". GAZETTEER_RSC is also the key under
# which gazetteer data lives in resource dicts (see merge_gazetteer_resource
# and ingest_dynamic_gazetteer). The remaining keys are not referenced by the
# code in this module.
GAZETTEER_RSC = "gazetteers"
QUERY_FREQ_RSC = "q_freq"
SYS_TYPES_RSC = "sys_types"
ENABLE_STEMMING = "enable-stemming"
WORD_FREQ_RSC = "w_freq"
WORD_NGRAM_FREQ_RSC = "w_ngram_freq"
CHAR_NGRAM_FREQ_RSC = "c_ngram_freq"
SENTIMENT_ANALYZER = "vader_classifier"

class ModelType(enum.Enum):
    """Kinds of models; each member's value is its short string identifier."""

    TEXT_MODEL = "text"
    TAGGER_MODEL = "tagger"
def create_model(config):
    """Build a model from a configuration using the model class registered as "auto".

    Args:
        config (ModelConfig): A model configuration

    Returns:
        Model: a configured model

    Raises:
        ValueError: When model configuration is invalid
    """
    # TODO: deprecate MODEL_MAP and use ModelFactory instead (be aware of cyclic imports)
    try:
        auto_model_class = MODEL_MAP["auto"]
        return auto_model_class.create_model_from_config(config)
    except KeyError as err:
        # Any KeyError raised while building the model (not only a missing
        # "auto" entry) is reported as an unknown config.model_type.
        raise ValueError(
            "Invalid model configuration: Unknown model type {!r}".format(
                config.model_type
            )
        ) from err
def load_model(path):
    """Load a saved model via the model class registered as "auto".

    Args:
        path (str): A path where the model configuration is pickled along with
            other metadata

    Returns:
        dict: metadata loaded from the path, which contains the configured model
            in 'model' key and the model configs in 'model_config' key along
            with other keys

    Raises:
        KeyError: If no "auto" model class is registered in MODEL_MAP.
        ValueError: When model configuration is invalid (presumably raised by
            ``create_model_from_path`` -- TODO confirm).
    """
    # TODO: deprecate MODEL_MAP and use ModelFactory instead (be aware of cyclic imports)
    auto_model_class = MODEL_MAP["auto"]
    return auto_model_class.create_model_from_path(path)
def create_annotator(config):
    """Creates an annotator instance using the provided configuration.

    The ``annotator_class`` entry selects a class registered in ANNOTATOR_MAP;
    every other entry is passed to that class's constructor as a keyword
    argument. The caller's ``config`` is left unmodified, so it can be reused.

    Args:
        config (dict): An annotator configuration containing an
            ``annotator_class`` key.

    Returns:
        Annotator: An instance of the selected annotator class.

    Raises:
        KeyError: If ``annotator_class`` is missing from the config or does not
            name a registered annotator class.
    """
    if "annotator_class" not in config:
        raise KeyError(
            "Missing required argument in AUTO_ANNOTATOR_CONFIG: 'annotator_class'"
        )
    annotator_class_name = config["annotator_class"]
    if annotator_class_name not in ANNOTATOR_MAP:
        msg = "Invalid model configuration: Unknown model type {!r}"
        raise KeyError(msg.format(annotator_class_name))
    # Work on a copy: popping the selector key from the caller's dict made a
    # second call with the same config fail with "Missing required argument".
    annotator_kwargs = dict(config)
    del annotator_kwargs["annotator_class"]
    return ANNOTATOR_MAP[annotator_class_name](**annotator_kwargs)
def get_feature_extractor(example_type, name):
    """Look up a registered feature extractor.

    Args:
        example_type (str): The type of example the extractor applies to.
        name (str): The name the extractor was registered under.

    Returns:
        function: The registered feature extractor.

    Raises:
        KeyError: If nothing is registered for ``example_type`` / ``name``.
    """
    extractors_for_type = FEATURE_MAP[example_type]
    return extractors_for_type[name]
def get_label_encoder(config):
    """Instantiate the label encoder registered for ``config.label_type``.

    Args:
        config (ModelConfig): A model configuration; its ``label_type`` selects
            the encoder, and the whole config is passed to the encoder.

    Returns:
        LabelEncoder: The appropriate LabelEncoder object for the given config
    """
    encoder_class = LABEL_MAP[config.label_type]
    return encoder_class(config)
def create_embedder_model(app_path, config):
    """Creates and loads an embedder model.

    Args:
        app_path: The application path, passed to the embedder class as the
            ``app_path`` keyword argument.
        config (dict): Embedder settings, given either directly
            (``{"embedder_type": ..., ...}``) or nested under a non-empty
            ``"model_settings"`` entry. ``embedder_type`` is required. All
            settings, including ``embedder_type``, are passed to the embedder
            class as keyword arguments.

    Returns:
        Embedder: An instance of the registered embedder class.

    Raises:
        KeyError: When ``embedder_type`` is missing or empty.
        ValueError: When ``embedder_type`` is not registered in EMBEDDER_MAP
            (a KeyError raised inside the embedder constructor is reported the
            same way).
    """
    # A missing or empty "model_settings" entry means the settings are the
    # top-level config itself.
    embedder_config = config.get("model_settings") or config
    embedder_type = embedder_config.get("embedder_type")
    if not embedder_type:
        raise KeyError(
            "Missing required argument in config supplied to create embedder model: 'embedder_type'"
        )
    try:
        # cache_path for embedder, if required, needs to be included as a key in the embedder_config
        return EMBEDDER_MAP[embedder_type](app_path=app_path, **embedder_config)
    except KeyError as e:
        msg = "Invalid model configuration: Unknown embedder type {!r}"
        raise ValueError(msg.format(embedder_type)) from e
def register_model(model_type, model_class):
    """Registers a model class in MODEL_MAP under ``model_type``.

    ``create_model()`` and ``load_model()`` use the class registered as "auto".
    An existing registration under the same ``model_type`` is replaced.

    Args:
        model_type (str): The model type as specified in model configs
        model_class (class): The model to register
    """
    # TODO: deprecate MODEL_MAP var in lieu of ModelFactory
    MODEL_MAP[model_type] = model_class
def register_query_feature(feature_name):
    """Decorator factory that registers a query-level feature extractor.

    Args:
        feature_name (str): The name of the query feature

    Returns:
        (func): a decorator that registers and returns the feature extractor
    """
    decorator = register_feature(QUERY_EXAMPLE_TYPE, feature_name=feature_name)
    return decorator
def register_entity_feature(feature_name):
    """Decorator factory that registers an entity-level feature extractor.

    Args:
        feature_name (str): The name of the entity feature

    Returns:
        (func): a decorator that registers and returns the feature extractor
    """
    decorator = register_feature(ENTITY_EXAMPLE_TYPE, feature_name=feature_name)
    return decorator
def register_annotator(annotator_class_name, annotator_class):
    """Registers an Annotator class for use with `create_annotator()`.

    Args:
        annotator_class_name (str): The name used as the ``annotator_class``
            value in the config passed to ``create_annotator()``.
        annotator_class (class): The annotator class to register. An existing
            registration under the same name is replaced.
    """
    ANNOTATOR_MAP[annotator_class_name] = annotator_class
def register_augmentor(augmentor_name, augmentor_class):
    """Registers an augmentor class in ``AUGMENTATION_MAP``.

    Args:
        augmentor_name (str): The name to register the augmentor class under.
        augmentor_class (class): The augmentor class to register. An existing
            registration under the same name is replaced.
    """
    AUGMENTATION_MAP[augmentor_name] = augmentor_class
def register_feature(feature_type, feature_name):
    """Decorator factory that adds a feature extractor to FEATURE_MAP.

    The decorated function is stored as
    ``FEATURE_MAP[feature_type][feature_name]`` (replacing any existing entry)
    and returned unchanged, so it can later be retrieved with
    ``get_feature_extractor(feature_type, feature_name)``.

    Args:
        feature_type (str): The example type; must be QUERY_EXAMPLE_TYPE or
            ENTITY_EXAMPLE_TYPE ('query' or 'entity').
        feature_name (str): The name under which the extractor is registered.

    Returns:
        (func): A decorator that registers and returns the feature extractor.

    Raises:
        TypeError: When the returned decorator is applied and ``feature_type``
            is not a supported example type (validation happens at decoration
            time, not when ``register_feature`` itself is called).
    """

    def add_feature(func):
        if feature_type not in {QUERY_EXAMPLE_TYPE, ENTITY_EXAMPLE_TYPE}:
            raise TypeError("Feature type can only be 'query' or 'entity'")
        # Create the per-type mapping on first use, then add func to it.
        FEATURE_MAP.setdefault(feature_type, {})[feature_name] = func
        return func

    return add_feature
def register_label(label_type, label_encoder):
    """Register a label encoder class for use with `get_label_encoder()`.

    Args:
        label_type (str): The label type of the label encoder
        label_encoder (LabelEncoder): The label encoder class to register

    Raises:
        ValueError: If the label type is already registered
    """
    if label_type not in LABEL_MAP:
        LABEL_MAP[label_type] = label_encoder
        return
    raise ValueError(
        "Label encoder for label type {!r} is already registered.".format(label_type)
    )
def register_embedder(embedder_type, embedder):
    """Register an embedder class for use with `create_embedder_model()`.

    Args:
        embedder_type (str): The ``embedder_type`` config value that selects
            this embedder.
        embedder (class): The embedder class to register.

    Raises:
        ValueError: If ``embedder_type`` is already registered.
    """
    already_registered = embedder_type in EMBEDDER_MAP
    if already_registered:
        raise ValueError(
            "Embedder of type {!r} is already registered.".format(embedder_type)
        )
    EMBEDDER_MAP[embedder_type] = embedder
def mask_numerics(token):
    """Mask the digits in a token.

    Args:
        token (str): A string

    Returns:
        str: "#NUM" when the whole token is digits; otherwise the token with
            every digit character replaced by "8".
    """
    return "#NUM" if token.isdigit() else re.sub(r"\d", "8", token)
def get_ngram(tokens, start, length):
    """Return the space-joined n-gram of ``length`` tokens beginning at ``start``.

    Positions that fall outside ``tokens`` (negative or past the end) are
    filled with OUT_OF_BOUNDS_TOKEN.

    Args:
        tokens (list of str): Word tokens.
        start (int): The index of the desired ngram's start position.
        length (int): The length of the n-gram, e.g. 1 for unigram, etc.

    Returns:
        (str) An n-gram in the input token list.
    """
    token_count = len(tokens)
    window = (
        tokens[pos] if 0 <= pos < token_count else OUT_OF_BOUNDS_TOKEN
        for pos in range(start, start + length)
    )
    return " ".join(window)
def get_ngrams_upto_n(tokens, n):
    """Generate every n-gram of ``tokens`` of length 1 up to ``n``.

    Args:
        tokens (list of str): Word tokens.
        n (int): The length of n-gram upto which the ngram tokens are generated

    Yields:
        tuple: (ngram, (start index, end index)); the token span is inclusive
            on both ends.
    """
    if n == 0:
        # This is a generator, so the list only becomes StopIteration.value;
        # callers simply see no items.
        return []
    for size in range(1, n + 1):
        for first, gram in enumerate(nltk.ngrams(tokens, size)):
            # A gram of `size` tokens starting at `first` ends at first + size - 1.
            yield gram, (first, first + size - 1)
def get_seq_accuracy_scorer():
    """Wrap ``sequence_accuracy_scoring`` as an sklearn scorer (e.g. for GridSearchCV)."""
    scorer = make_scorer(score_func=sequence_accuracy_scoring)
    return scorer
def get_seq_tag_accuracy_scorer():
    """Wrap ``sequence_tag_accuracy_scoring`` as an sklearn scorer (e.g. for GridSearchCV)."""
    scorer = make_scorer(score_func=sequence_tag_accuracy_scoring)
    return scorer
def sequence_accuracy_scoring(y_true, y_pred):
    """Fraction of sequences whose predicted tags all equal the true tags.

    Sequences are paired positionally; true sequences without a predicted
    counterpart count as mismatches.

    Args:
        y_true (list): A sequence of true expected labels
        y_pred (list): A sequence of predicted labels

    Returns:
        float: The sequence-level accuracy (the int 0 when ``y_true`` is empty)
    """
    total = len(y_true)
    if total == 0:
        return 0
    exact_matches = sum(
        1 for true_seq, pred_seq in zip(y_true, y_pred) if true_seq == pred_seq
    )
    return exact_matches / total
def sequence_tag_accuracy_scoring(y_true, y_pred):
    """Accuracy score which calculates the number of tags that were predicted correctly.

    Tags are compared position-wise *within* each (true, predicted) sequence
    pair. Predicted tags never spill over into the next sequence, so a length
    mismatch in one sequence cannot shift the alignment of later ones. Any
    true tag without a predicted counterpart counts as incorrect.

    Args:
        y_true (list): A sequence of true expected label sequences
        y_pred (list): A sequence of predicted label sequences

    Returns:
        float: The tag-level accuracy when comparing the predicted labels
            against the true expected labels (the int 0 when there are no
            true tags)
    """
    # Materialize once (as the previous flattening did) so each sequence can
    # be measured and then iterated.
    true_seqs = [list(seq) for seq in y_true]
    pred_seqs = [list(seq) for seq in y_pred]
    total = sum(len(seq) for seq in true_seqs)
    if not total:
        return 0
    matches = sum(
        1
        for true_seq, pred_seq in zip(true_seqs, pred_seqs)
        for true_tag, pred_tag in zip(true_seq, pred_seq)
        if true_tag == pred_tag
    )
    return float(matches) / float(total)
def entity_seqs_equal(expected, predicted):
    """Check whether two entity sequences match element by element.

    Two entities match when their ``entity.type``, ``span`` and ``text`` are all
    equal; the sequences must also have the same length.

    Args:
        expected (list of core.Entity): A list of the expected entities for some query
        predicted (list of core.Entity): A list of the predicted entities for some query

    Returns:
        bool: True if every pair matches, False otherwise.
    """
    if len(expected) != len(predicted):
        return False
    return not any(
        exp.entity.type != pred.entity.type
        or exp.span != pred.span
        or exp.text != pred.text
        for exp, pred in zip(expected, predicted)
    )
def merge_gazetteer_resource(resource, dynamic_resource, text_preparation_pipeline):
    """Merge the gazetteers of a dynamic resource into a copy of ``resource``.

    Non-gazetteer entries are shared by reference with ``resource``. Under the
    gazetteer key a fresh dict is built. Entity types present in both resources
    get a newly built merged gazetteer (so the original's data is not modified).
    All other entity types are shared by reference. Entity types that appear
    only in ``dynamic_resource`` are not added.

    Args:
        resource (dict): The original resource built from the app
        dynamic_resource (dict): The dynamic resource passed in
        text_preparation_pipeline (TextPreparationPipeline): For text tokenization and normalization

    Returns:
        dict: The merged resource
    """
    merged = {}
    for rsc_key in resource:
        original_value = resource[rsc_key]
        if rsc_key != GAZETTEER_RSC:
            merged[rsc_key] = original_value
            continue

        merged_gazetteers = {}
        merged[rsc_key] = merged_gazetteers
        for entity_type in original_value:
            # Looked up per entity type, so an empty gazetteer section never
            # touches dynamic_resource.
            dynamic_gazetteers = dynamic_resource[rsc_key]
            if entity_type not in dynamic_gazetteers:
                merged_gazetteers[entity_type] = original_value[entity_type]
                continue
            # Rebuild from the stored dict so _update_entity works on a copy,
            # not on the original resource's data.
            gazetteer = Gazetteer(entity_type, text_preparation_pipeline)
            gazetteer.from_dict(original_value[entity_type])
            dynamic_entries = dynamic_gazetteers[entity_type]
            for entity in dynamic_entries:
                gazetteer._update_entity(
                    text_preparation_pipeline.normalize(entity),
                    dynamic_entries[entity],
                )
            merged_gazetteers[entity_type] = gazetteer.to_dict()
    return merged
def ingest_dynamic_gazetteer(resource, dynamic_resource=None, text_preparation_pipeline=None):
    """Return ``resource`` with any dynamic gazetteers merged in.

    Args:
        resource (dict): The original resource
        dynamic_resource (dict, optional): The dynamic resource that needs to be ingested
        text_preparation_pipeline (TextPreparationPipeline): For text tokenization and
            normalization; a default pipeline is created when this is falsy.

    Returns:
        (dict): A new merged resource, or ``resource`` itself when
            ``dynamic_resource`` has no gazetteer entry.
    """
    if dynamic_resource and GAZETTEER_RSC in dynamic_resource:
        pipeline = (
            text_preparation_pipeline
            or TextPreparationPipelineFactory.create_default_text_preparation_pipeline()
        )
        return merge_gazetteer_resource(resource, dynamic_resource, pipeline)
    # Nothing dynamic to merge: hand back the very same resource object.
    return resource
def requires(resource):
    """Decorator declaring that a feature extractor needs a classifier resource.

    The resource key is added to the function's ``requirements`` set (created
    on first use), so stacking several ``@requires`` accumulates keys.

    Args:
        resource (str): the key of a classifier resource which must be initialized before
            the given feature extractor is used

    Returns:
        (func): a decorator returning the same feature extractor
    """

    def _attach_requirement(feature_extractor):
        needed = feature_extractor.__dict__.get("requirements", set())
        needed.add(resource)
        feature_extractor.requirements = needed
        return feature_extractor

    return _attach_requirement
def np_encoder(val):
    """``default`` hook for ``json.dumps`` that converts NumPy scalars to Python values.

    Args:
        val: The object json could not serialize natively.

    Returns:
        The equivalent built-in Python scalar (via ``np.generic.item()``).

    Raises:
        TypeError: If ``val`` is not a NumPy scalar.
    """
    if not isinstance(val, np.generic):
        raise TypeError(f"{type(val)} cannot be serialized by JSON.")
    return val.item()
class FileBackedList:
    """
    FileBackedList implements an interface for simple list use cases that is backed by a
    temporary file on disk. This is useful for simple list processing in a memory efficient way.

    Each appended item is written as one JSON document per line (NumPy scalars are
    converted by ``np_encoder``) and read back lazily, one line at a time, when the
    list is iterated. The temporary file is deleted when the list is garbage collected.
    """

    def __init__(self):
        self.num_lines = 0
        self.file_handle = None  # write handle, opened lazily by append()
        fd, self.filename = mkstemp()
        # Only the path is kept; close the OS-level descriptor mkstemp() opened.
        os.close(fd)

    def __len__(self):
        return self.num_lines

    def append(self, line):
        """Serialize ``line`` as JSON and append it to the backing file.

        Args:
            line: A JSON-serializable value (NumPy scalars are accepted).
        """
        if self.file_handle is None:
            # Append mode, not "w": __iter__ closes the write handle, so an
            # append after iterating must not truncate the items already stored.
            self.file_handle = open(self.filename, "a")
        self.file_handle.write(json.dumps(line, default=np_encoder))
        self.file_handle.write("\n")
        self.num_lines += 1

    def __del__(self):
        # getattr guards against a partially initialized instance (e.g. mkstemp raised).
        if getattr(self, "file_handle", None):
            self.file_handle.close()
        filename = getattr(self, "filename", None)
        if filename is not None:
            try:
                os.unlink(filename)
            except FileNotFoundError:
                pass

    def __iter__(self):
        # Flush out any remaining data to be written
        if self.file_handle:
            self.file_handle.close()
            self.file_handle = None
        return FileBackedList.Iterator(self)

    class Iterator:
        """Reads a FileBackedList's items back, one JSON line per ``next()`` call."""

        def __init__(self, source):
            # Holding the source keeps it (and therefore its file) alive while reading.
            self.source = source
            self.file_handle = open(source.filename, "r")

        def __iter__(self):
            return self

        def __len__(self):
            return len(self.source)

        def __next__(self):
            if self.file_handle is None:
                # Already exhausted (or failed): keep raising StopIteration
                # instead of failing on a closed/None handle.
                raise StopIteration
            try:
                line = next(self.file_handle)
                return json.loads(line)
            except StopIteration:
                self._close()
                raise
            except Exception:
                self._close()
                logger.error("Error reading from FileBackedList")
                raise

        def _close(self):
            """Close the read handle (if open) and mark the iterator finished."""
            if self.file_handle:
                self.file_handle.close()
            self.file_handle = None

        def __del__(self):
            if getattr(self, "file_handle", None):
                self.file_handle.close()