Source code for mindmeld.gazetteer

# -*- coding: utf-8 -*-
#
# Copyright (c) 2015 Cisco Systems, Inc. and others.  All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import codecs
import logging
import os
from collections import defaultdict

import joblib

logger = logging.getLogger(__name__)


class Gazetteer:
    """
    This class holds the following fields, which are extracted and exported to file.

    Attributes:
        entity_count (int): Total entities in the file.
        pop_dict (dict): A dictionary keyed by the tokenized, normalized entity
            name, with the popularity score as the value. If more than one
            entity shares the same name, the popularity is the maximum value
            across all duplicate entities.
        index (dict): The inverted index, which maps terms and n-grams to the
            set of documents that contain them.
        entities (list): A list of all entities.
        sys_types (set): The set of nested numeric types for this entity.
    """

    def __init__(self, name, text_preparation_pipeline, exclude_ngrams=False):
        """
        Args:
            name (str): The name of this gazetteer.
            text_preparation_pipeline (TextPreparationPipeline): Pipeline for
                tokenization and normalization of text.
            exclude_ngrams (bool): Whether to exclude n-grams when indexing entities.
        """
        self.name = name
        self.exclude_ngrams = exclude_ngrams
        self.max_ngram = 1
        self.entity_count = 0
        self.pop_dict = defaultdict(int)
        self.index = defaultdict(set)
        self.entities = []
        self.sys_types = set()
        self.text_preparation_pipeline = text_preparation_pipeline

    def to_dict(self):
        """
        Returns:
            dict: A serializable representation of this gazetteer.
        """
        return {
            "name": self.name,
            "total_entities": self.entity_count,
            "pop_dict": self.pop_dict,
            "index": self.index,
            "entities": self.entities,
            "sys_types": self.sys_types,
        }
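
    # Example (illustrative, not output from a real run): for a gazetteer
    # named "city" holding a single entity, to_dict() would return something like:
    #     {"name": "city", "total_entities": 1,
    #      "pop_dict": {("san", "francisco"): 1000.0},
    #      "index": {"san": {0}, "francisco": {0}},
    #      "entities": ["san francisco"], "sys_types": set()}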

    def from_dict(self, serialized_gaz):
        """De-serializes the gazetteer object from a dictionary using shallow copy ops.

        Args:
            serialized_gaz (dict): The serialized gaz object
        """
        for key, value in serialized_gaz.items():
            # We only shallow copy lists and dicts here since this container has
            # no nested data structures, only 1-level dictionaries and lists, so
            # only the top-level references need to be copied. All other types,
            # like strings, can just be passed by value.
            setattr(
                self, key, value.copy() if isinstance(value, (list, dict)) else value
            )

    def dump(self, gaz_path):
        """Persists the gazetteer to disk.

        Args:
            gaz_path (str): The location on disk where the gazetteer should be stored
        """
        # make directory if necessary
        folder = os.path.dirname(gaz_path)
        if not os.path.isdir(folder):
            os.makedirs(folder)

        joblib.dump(self.to_dict(), gaz_path)

    def load(self, gaz_path):
        """Loads the gazetteer from disk.

        Args:
            gaz_path (str): The location on disk where the gazetteer is stored
        """
        gaz_data = joblib.load(gaz_path)
        self.name = gaz_data["name"]
        self.entity_count = gaz_data["total_entities"]
        self.pop_dict = gaz_data["pop_dict"]
        self.index = gaz_data["index"]
        self.entities = gaz_data["entities"]
        self.sys_types = gaz_data["sys_types"]

    def _update_entity(self, entity, popularity, keep_max=True):
        """
        Updates all gazetteer data with an entity and its popularity.

        Args:
            entity (str): A normalized entity name.
            popularity (float): The entity's popularity value.
            keep_max (bool): If True and the entity is already in the pop_dict,
                set its popularity to the max of the given popularity and the
                value already in the pop_dict. Otherwise, overwrite it.
        """
        # Only update the relevant data structures when the entity isn't
        # already in the gazetteer. Update the popularity either way.
        tokenized_gaz_entry = tuple(
            token["entity"]
            for token in self.text_preparation_pipeline.tokenize_and_normalize(entity)
        )
        if self.pop_dict[tokenized_gaz_entry] == 0:
            self.entities.append(entity)
            if not self.exclude_ngrams:
                for ngram in iterate_ngrams(entity.split(), max_length=self.max_ngram):
                    self.index[ngram].add(self.entity_count)
            self.entity_count += 1

        if keep_max:
            old_value = self.pop_dict[tokenized_gaz_entry]
            self.pop_dict[tokenized_gaz_entry] = max(
                self.pop_dict[tokenized_gaz_entry], popularity
            )
            if self.pop_dict[tokenized_gaz_entry] != old_value:
                logger.debug(
                    "Updating gazetteer value of entity %s from %s to %s",
                    entity,
                    old_value,
                    self.pop_dict[tokenized_gaz_entry],
                )
        else:
            self.pop_dict[tokenized_gaz_entry] = popularity
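
    # Example (illustrative sketch, assuming the pipeline lowercases and splits
    # on whitespace): after `gaz._update_entity("san francisco", 1000.0)` on a
    # fresh gazetteer we would expect:
    #     gaz.entities == ["san francisco"]
    #     gaz.pop_dict[("san", "francisco")] == 1000.0
    #     gaz.index["san"] == {0} and gaz.index["francisco"] == {0}
    # With the default max_ngram of 1, only unigrams are indexed.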

    def update_with_entity_data_file(self, filename, popularity_cutoff, normalizer):
        """
        Updates this gazetteer with data from an entity data file.

        Args:
            filename (str): The filename of the entity data file.
            popularity_cutoff (float): Entities whose popularity does not exceed
                this threshold are ignored.
            normalizer (function): A function that normalizes text.
        """
        logger.info("Loading entity data from '%s'", filename)
        line_count = 0
        entities_added = 0
        num_cols = None
        if not os.path.isfile(filename):
            logger.warning("Entity data file was not found at %s", filename)
        else:
            with codecs.open(filename, encoding="utf8") as data_file:
                for i, row in enumerate(data_file.readlines()):
                    if not row:
                        continue
                    split_row = row.strip("\n").split("\t")
                    if num_cols is None:
                        num_cols = len(split_row)
                    if len(split_row) != num_cols:
                        msg = "Row {} of .tsv file '{}' malformed, expected {} columns"
                        raise ValueError(msg.format(i + 1, filename, num_cols))
                    if num_cols == 2:
                        pop, entity = split_row
                    else:
                        pop = 1.0
                        entity = split_row[0]
                    pop = 0 if pop == "null" else float(pop)
                    line_count += 1
                    entity = normalizer(entity)
                    if pop > popularity_cutoff:
                        self._update_entity(entity, float(pop))
                        entities_added += 1
            logger.info(
                "%d/%d entities in entity data file exceeded popularity "
                "cutoff and were added to the gazetteer",
                entities_added,
                line_count,
            )
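
    # Example entity data file (illustrative): each row is either
    # "popularity<TAB>entity" or a bare entity name (popularity then defaults
    # to 1.0); a popularity of "null" is read as 0:
    #
    #     1000\tsan francisco
    #     250\toakland
    #     null\tberkeley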

    def update_with_entity_map(
        self, mapping, normalizer, update_if_missing_canonical=True
    ):
        """Updates the gazetteer with a list of normalized key-value pairs from the
        input mapping list.

        Args:
            mapping (list): A list of dicts containing canonical names and
                whitelists of a particular entity.
            normalizer (func): A QueryFactory normalization function that is
                used to normalize the input mapping data before it is added to
                the gazetteer.
            update_if_missing_canonical (bool): Whether to add synonyms whose
                canonical name is not already in the gazetteer.
        """
        logger.info("Loading synonyms from entity mapping")
        line_count = 0
        synonyms_added = 0
        missing_canonicals = 0
        min_popularity = 0
        if len(self.pop_dict) > 0:
            min_popularity = min(self.pop_dict.values())

        for item in mapping:
            tokenized_canonical = tuple(normalizer(item["cname"]).split())
            for syn in item["whitelist"]:
                line_count += 1
                synonym = normalizer(syn)

                if update_if_missing_canonical or tokenized_canonical in self.pop_dict:
                    self._update_entity(
                        synonym, self.pop_dict.get(tokenized_canonical, min_popularity)
                    )
                    synonyms_added += 1
                if tokenized_canonical not in self.pop_dict:
                    missing_canonicals += 1
                    logger.debug(
                        "Synonym '%s' for entity '%s' not in gazetteer",
                        synonym,
                        str(tokenized_canonical),
                    )

        logger.info(
            "Added %d/%d synonyms from file into gazetteer", synonyms_added, line_count
        )
        if update_if_missing_canonical and missing_canonicals:
            logger.info(
                "Loaded %d synonyms where the canonical name is not in the gazetteer",
                missing_canonicals,
            )
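
# A minimal usage sketch (illustrative; `pipeline` is assumed to be a
# configured TextPreparationPipeline, "stores.tsv" a hypothetical entity data
# file, and `str.lower` a stand-in for a real normalizer): build a gazetteer
# from a data file, persist it, and load it back.
def _example_round_trip(pipeline):
    gaz = Gazetteer("store_name", pipeline)
    gaz.update_with_entity_data_file("stores.tsv", 0.0, str.lower)
    gaz.dump("gaz/store_name.pkl")

    restored = Gazetteer("store_name", pipeline)
    restored.load("gaz/store_name.pkl")
    return restored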

class NestedGazetteer:
    """This class represents a gazetteer entry corresponding to a Query object."""

    def __init__(
        self, start_token_index, end_token_index_plus_one, gaz_name, token_ngram, raw_ngram
    ):
        self._start_token_index = start_token_index
        self._end_token_index_plus_one = end_token_index_plus_one
        self._gaz_name = gaz_name
        self._token_ngram = token_ngram
        self._raw_ngram = raw_ngram

    @property
    def start_token_index(self):
        return self._start_token_index

    @property
    def end_token_index_plus_one(self):
        return self._end_token_index_plus_one

    @property
    def gaz_name(self):
        return self._gaz_name

    @property
    def token_ngram(self):
        return self._token_ngram

    @property
    def raw_ngram(self):
        return self._raw_ngram

    def __gt__(self, other_gaz):
        return self.start_token_index > other_gaz.start_token_index
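
# Example (illustrative): NestedGazetteer defines __gt__ on start_token_index,
# so a list of matched spans can be ordered by where they begin (sorted() falls
# back on the reflected __gt__ here):
#
#     spans = [
#         NestedGazetteer(3, 5, "city", "san francisco", "San Francisco"),
#         NestedGazetteer(0, 1, "store_name", "target", "Target"),
#     ]
#     sorted(spans)[0].gaz_name  # -> "store_name"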

def iterate_ngrams(tokens, min_length=1, max_length=1):
    """Iterates over all n-grams in a list of tokens.

    Args:
        tokens (list of str): A list of word tokens.
        min_length (int): The minimum length of n-gram to yield.
        max_length (int): The maximum length of n-gram to yield.

    Yields:
        (str) An n-gram from the input tokens list.
    """
    max_length = min(len(tokens), max_length)
    unrolled_tokens = [tokens[i:] for i in range(max_length)]

    for length in range(min_length, max_length + 1):
        for ngram in zip(*unrolled_tokens[:length]):
            yield " ".join(ngram)
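
# Example (illustrative): all n-grams of length 1 and 2 over three tokens,
# yielded shortest-first:
#
#     list(iterate_ngrams(["san", "francisco", "bay"], max_length=2))
#     # -> ["san", "francisco", "bay", "san francisco", "francisco bay"]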