Source code for mindmeld.components.entity_resolver

# -*- coding: utf-8 -*-
#
# Copyright (c) 2015 Cisco Systems, Inc. and others.  All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module contains the entity resolver component of the MindMeld natural language processor.
"""
import copy
import hashlib
import json
import logging
import os
import pickle
import re
import uuid
import warnings
from abc import ABC, abstractmethod
from collections import OrderedDict
from string import punctuation

import numpy as np
import scipy
from sklearn.feature_extraction.text import TfidfVectorizer
from tqdm.auto import trange

from ._config import (
    get_app_namespace,
    get_classifier_config,
)
from ._util import _is_module_available, _get_module_or_attr as _getattr
from ..core import Entity
from ..exceptions import (
    ElasticsearchConnectionError,
    EntityResolverError
)
from ..models import create_embedder_model
from ..resource_loader import ResourceLoader, Hasher

if _is_module_available("elasticsearch"):
    from ._elasticsearch_helpers import (
        INDEX_TYPE_KB,
        INDEX_TYPE_SYNONYM,
        DOC_TYPE,
        DEFAULT_ES_SYNONYM_MAPPING,
        PHONETIC_ES_SYNONYM_MAPPING,
        create_es_client,
        delete_index,
        does_index_exist,
        get_field_names,
        get_scoped_index_name,
        load_index,
        resolve_es_config_for_version,
    )

logger = logging.getLogger(__name__)

DEFAULT_TOP_N = 20


[docs]class EntityResolverFactory: @staticmethod def _correct_deprecated_er_config(er_config): """ for backwards compatibility if `er_config` is supplied in deprecated format, its format is corrected and returned, else it is not modified and returned as-is deprecated usage >>> er_config = { "model_type": "text_relevance", "model_settings": { ... } } new usage >>> er_config = { "model_type": "resolver", "model_settings": { "resolver_type": "text_relevance" ... } } """ if not er_config.get("model_settings", {}).get("resolver_type"): model_type = er_config.get("model_type") if model_type == "resolver": raise ValueError( "Could not find `resolver_type` in `model_settings` of entity resolver") else: msg = "Using deprecated config format for Entity Resolver. " \ "See https://www.mindmeld.com/docs/userguide/entity_resolver.html " \ "for more details." warnings.warn(msg, DeprecationWarning) er_config = copy.deepcopy(er_config) model_settings = er_config.get("model_settings", {}) model_settings.update({"resolver_type": model_type}) er_config["model_settings"] = model_settings er_config["model_type"] = "resolver" return er_config @staticmethod def _validate_resolver_type(name): if name not in ENTITY_RESOLVER_MODEL_MAPPINGS: raise ValueError(f"Expected 'resolver_type' in config of Entity Resolver " f"among {[*ENTITY_RESOLVER_MODEL_MAPPINGS]} but found {name}") if name == "sbert_cosine_similarity" and not _is_module_available("sentence_transformers"): raise ImportError( "Must install the extra [bert] by running `pip install mindmeld[bert]` " "to use the built in embedder for entity resolution.") if name == "text_relevance" and not _is_module_available("elasticsearch"): raise ImportError( "Must install the extra [elasticsearch] by running " "`pip install mindmeld[elasticsearch]` " "to use Elasticsearch based entity resolution.")
    @classmethod
    def create_resolver(cls, app_path, entity_type, config=None, resource_loader=None, **kwargs):
        """
        Identifies the appropriate entity resolver based on the input config and returns it.

        Args:
            app_path (str): The application path.
            entity_type (str): The entity type associated with this entity resolver.
            config (dict, optional): A resolver config to use instead of the app's entity
                resolution config.
            resource_loader (ResourceLoader, optional): An object which can load resources for
                the resolver.
            kwargs: Additional keyword arguments, such as `es_host` (the Elasticsearch host
                server) and `es_client` (the Elasticsearch client), forwarded to the resolver.
        """
        er_config = config or get_classifier_config("entity_resolution", app_path=app_path)
        er_config = cls._correct_deprecated_er_config(er_config)

        resolver_type = er_config["model_settings"]["resolver_type"]
        cls._validate_resolver_type(resolver_type)

        resource_loader = (
            resource_loader or ResourceLoader.create_resource_loader(app_path=app_path)
        )

        return ENTITY_RESOLVER_MODEL_MAPPINGS.get(resolver_type)(
            app_path, entity_type, config=er_config, resource_loader=resource_loader, **kwargs)
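# Example usage of the factory (an illustrative sketch; "home_assistant" and "dish" are assumed
# placeholder values for an app path and entity type, not taken from this module):
#
#   >>> resolver = EntityResolverFactory.create_resolver("home_assistant", "dish")
#   >>> resolver.fit()
#   >>> resolver.predict("seaweed salad", top_n=5)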
[docs]class BaseEntityResolver(ABC): # pylint: disable=too-many-instance-attributes """ Base class for Entity Resolvers """ def __init__(self, app_path, entity_type, resource_loader=None, **_kwargs): """Initializes an entity resolver Args: app_path (str): The application path. entity_type (str): The entity type associated with this entity resolver. resource_loader (ResourceLoader, Optional): A resource loader object for the resolver. """ self.app_path = app_path self.type = entity_type self._resource_loader = ( resource_loader or ResourceLoader.create_resource_loader(app_path=self.app_path) ) self._model_settings = {} self._is_system_entity = Entity.is_system_entity(self.type) self._no_trainable_canonical_entity_map = False self.dirty = False # bool, True if exists any unsaved data/model that can be saved self.ready = False # bool, True if the model is already fitted or loaded self.hash = "" def __repr__(self): msg = "<{} ready: {!r}, dirty: {!r}, app_path: {!r}, entity_type: {!r}>" return msg.format(self.__class__.__name__, self.ready, self.dirty, self.app_path, self.type) @property def resolver_configurations(self): return self._model_settings @resolver_configurations.setter @abstractmethod def resolver_configurations(self, model_settings): """Sets the configurations for the resolver that are used while creating a dump of configs """ raise NotImplementedError
    def fit(self, clean=False, entity_map=None):
        """Fits the resolver model, if required.

        Args:
            clean (bool, optional): If ``True``, deletes and recreates the index from scratch
                with synonyms in the mapping.json.
            entity_map (Dict[str, Union[str, List]]): Entity map if passed in directly instead of
                loading from a file path

        Raises:
            EntityResolverError: if the resolver cannot be fit with the loaded/passed-in data

        Example of an entity_map.json file:
        -----------------------------------
        entity_map = {
            "some_optional_key": "value",
            "entities": [
                {
                    "id": "B01MTUORTQ",
                    "cname": "Seaweed Salad",
                    "whitelist": [...],
                },
                ...
            ],
        }
        """
        msg = f"Fitting {self.__class__.__name__} entity resolver for entity_type {self.type}"
        logger.info(msg)

        if self.ready and not clean:
            return

        if self._is_system_entity:
            self._no_trainable_canonical_entity_map = True
            self.ready = True
            self.dirty = True  # configs need to be saved even for sys entities
            return

        entity_map = entity_map or self._get_entity_map()
        entities_data = entity_map.get("entities", [])
        if not entities_data:
            self._no_trainable_canonical_entity_map = True
            self.ready = True
            self.dirty = True
            return

        # obtain hash
        # hash based on the KB data before any processing
        new_hash = self._get_model_hash(entities_data)

        # check if a fitted model is already available for this hash value
        cached_model_path = self._resource_loader.hash_to_model_path.get(new_hash)
        if cached_model_path:
            msg = f"A fit {self.__class__.__name__} model for the found KB data is already " \
                  f"available. Loading the model instead of fitting again. Pass 'clean=True' to " \
                  f"the .fit() method in case you wish to force a re-fitting."
            logger.info(msg)
            self.load(cached_model_path, entity_map=entity_map)
            return

        # reformat (if required) and fit the resolver model
        entity_map["entities"] = self._format_entity_map(entities_data)
        try:
            self._fit(clean, entity_map)
        except Exception as e:
            msg = f"Error in {self.__class__.__name__} while fitting the resolver model with " \
                  f"clean={clean}"
            raise EntityResolverError(msg) from e

        self.hash = new_hash
        self.ready = True
        self.dirty = True
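    # Example of fitting with an in-memory entity map (an illustrative sketch; the entry mirrors
    # the docstring above, and the "sea weed salad" synonym is an assumed placeholder):
    #
    #   >>> entity_map = {
    #   ...     "entities": [
    #   ...         {"id": "B01MTUORTQ", "cname": "Seaweed Salad", "whitelist": ["sea weed salad"]},
    #   ...     ],
    #   ... }
    #   >>> resolver.fit(entity_map=entity_map)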
    def predict(self, entity_or_list_of_entities, top_n=DEFAULT_TOP_N, allowed_cnames=None):
        """Predicts the resolved value(s) for the given entity using the loaded entity map or the
        trained entity resolution model.

        Args:
            entity_or_list_of_entities (Entity, tuple[Entity], str, tuple[str]): One or more
                entity query strings or Entity objects that need to be resolved.
            top_n (int, optional): maximum number of results to populate. If passed as 0 or
                `None`, the tfidf and embedder entity resolvers return an unsorted list of
                results. This is sometimes helpful when a developer wishes to do some wrapper
                operations on top of unsorted results, such as combining scores from multiple
                resolvers and then sorting, etc.
            allowed_cnames (Iterable, optional): if provided, predictions will only include
                objects related to these canonical names

        Returns:
            (list): The top n resolved values for the provided entity.

        Raises:
            EntityResolverError: if unable to obtain predictions for the given input
        """
        if not self.ready:
            msg = "Resolver not ready, model must be built (.fit()) or loaded (.load()) first."
            logger.error(msg)

        nbest_entities = entity_or_list_of_entities
        if not isinstance(nbest_entities, (list, tuple)):
            nbest_entities = tuple([nbest_entities])

        nbest_entities = tuple(
            [Entity(e, self.type) if isinstance(e, str) else e for e in nbest_entities]
        )

        if self._is_system_entity:
            # system entities are already resolved
            top_entity = nbest_entities[0]
            return [top_entity.value]

        if self._no_trainable_canonical_entity_map:
            return []

        if allowed_cnames:
            allowed_cnames = set(allowed_cnames)  # order doesn't matter

        # unsorted list in case of tfidf and embedder models; sorted in case of Elasticsearch
        try:
            results = self._predict(nbest_entities, allowed_cnames)
        except Exception as e:
            msg = f"Error in {self.__class__.__name__} while resolving entities for the " \
                  f"input: {entity_or_list_of_entities}"
            raise EntityResolverError(msg) from e

        return self._trim_and_sort_results(results, top_n)
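    # Example predictions (an illustrative sketch; assumes a fitted resolver and placeholder
    # query strings):
    #
    #   >>> resolver.predict("seaweed salad")                    # top DEFAULT_TOP_N results
    #   >>> resolver.predict("seaweed salad", top_n=1)           # best match only
    #   >>> resolver.predict("seaweed salad", top_n=None)        # unsorted (tfidf/embedder resolvers)
    #   >>> resolver.predict("seaweed salad", allowed_cnames=["Seaweed Salad"])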
    def dump(self, model_path, incremental_model_path=None):
        """
        Persists the trained classification model to disk. The state for an embedder based model
        is the cached embeddings whereas for text features based resolvers, (if required,) it
        will generally be a serialized pickle of the underlying model/algorithm and the
        associated data.

        In general, this method leads to creation of the following files:
            - .config.pkl: pickle of the resolver's configurable parameters
            - .pkl.hash: a hash string obtained from a combination of KB data and the config params
            - .pkl (optional, for non-ES models): pickle of the underlying model/algo state
            - .embedder_cache.pkl (optional, for embedder models): pickle of underlying embeddings

        Args:
            model_path (str): A .pkl file path where the resolver will be dumped. The model hash
                will be dumped at the {path}.hash file path
            incremental_model_path (str, optional): The timestamp folder where the cached
                models are stored.
        """
        msg = f"Dumping {self.__class__.__name__} entity resolver for entity_type {self.type}"
        logger.info(msg)

        if not self.ready:
            msg = "Resolver not ready, model must be built (.fit()) before dumping."
            logger.error(msg)
            raise EntityResolverError(msg)

        for path in [model_path, incremental_model_path]:
            if not path:
                continue

            # underlying resolver model/algorithm/embeddings specific dump
            self._dump(path)

            # save resolver configs
            # in case of classifiers (domain, intent, etc.), dumping configs is handled by the
            # models abstract layer
            head, ext = os.path.splitext(path)
            resolver_config_path = head + ".config" + ext
            os.makedirs(os.path.dirname(resolver_config_path), exist_ok=True)
            with open(resolver_config_path, "wb") as fp:
                pickle.dump(self.resolver_configurations, fp)

            # save data hash
            # this hash is useful for avoiding re-fitting the resolver on unchanged data
            hash_path = path + ".hash"
            os.makedirs(os.path.dirname(hash_path), exist_ok=True)
            with open(hash_path, "w") as hash_file:
                hash_file.write(self.hash)

            if path == model_path:
                self.dirty = False
    def load(self, path, entity_map=None):
        """
        Loads the state of the entity resolver as well as the KB data. The state for an embedder
        model is the cached embeddings whereas for text features based resolvers, (if required,)
        it will generally be a serialized pickle of the underlying model/algorithm. There is no
        state as such for the Elasticsearch resolver to be dumped.

        Args:
            path (str): A .pkl file path where the resolver has been dumped
            entity_map (Dict[str, Union[str, List]]): Entity map if passed in directly instead of
                loading from a file path

        Raises:
            EntityResolverError: if the resolver cannot be loaded from the specified path
        """
        msg = f"Loading {self.__class__.__name__} entity resolver for entity_type {self.type}"
        logger.info(msg)

        if self.ready:
            msg = f"The {self.__class__.__name__} entity resolver for entity_type {self.type} is " \
                  f"already loaded. If you wish to do a clean fit, you can call the fit method " \
                  f"as follows: .fit(clean=True)"
            logger.info(msg)
            return

        if self._is_system_entity:
            self._no_trainable_canonical_entity_map = True
            self.ready = True
            self.dirty = False
            return

        entity_map = entity_map or self._get_entity_map()
        entities_data = entity_map.get("entities", [])
        if not entities_data:
            self._no_trainable_canonical_entity_map = True
            self.ready = True
            self.dirty = False
            return

        # obtain hash
        # hash based on the KB data before any processing
        new_hash = self._get_model_hash(entities_data)

        hash_path = path + ".hash"
        with open(hash_path, "r") as hash_file:
            self.hash = hash_file.read()
        if new_hash != self.hash:
            msg = f"Found KB data to have changed when loading {self.__class__.__name__} " \
                  f"resolver ({str(self)}). Please fit using 'clean=True' " \
                  f"before loading a resolver for this KB. Found new data hash to be " \
                  f"'{new_hash}' whereas the hash during dumping was '{self.hash}'"
            logger.error(msg)
            raise ValueError(msg)

        # reformat (if required)
        entity_map["entities"] = self._format_entity_map(entities_data)

        # load resolver configs, if they exist
        head, ext = os.path.splitext(path)
        resolver_config_path = head + ".config" + ext
        if os.path.exists(resolver_config_path):
            with open(resolver_config_path, "rb") as fp:
                self.resolver_configurations = pickle.load(fp)
        else:
            msg = f"Cannot find a configs path for the resolver while loading the " \
                  f"resolver:{self.__class__.__name__}. This could have happened if the " \
                  f".dump() method of the resolver was not called before calling .load()."
            logger.debug(msg)
            self.resolver_configurations = {}

        # load underlying resolver model/algorithm/embeddings
        try:
            self._load(path, entity_map=entity_map)
        except Exception as e:
            msg = f"Error in {self.__class__.__name__} while loading the resolver from the " \
                  f"path: {path}"
            raise EntityResolverError(msg) from e

        self.ready = True
        self.dirty = False
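    # Example dump/load round trip (an illustrative sketch; the .pkl path is an assumed
    # placeholder). Alongside the given path, dump() also writes the .config.pkl and .pkl.hash
    # files described above:
    #
    #   >>> resolver.dump(".generated/entity_resolver.pkl")
    #   >>> resolver.unload()
    #   >>> resolver.load(".generated/entity_resolver.pkl")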
    # TODO: method to be removed in a future major release of MindMeld
    @abstractmethod
    def load_deprecated(self):
        """
        A method to handle the deprecated way of using the .load() method in entity resolvers.
        This ensures backwards compatibility when loading models that were built using an older
        version of MindMeld, i.e., a version <= 4.4.0. Since no hash pickle file is dumped in
        those older versions of MindMeld, using the latest .load() method throws a
        FileNotFoundError.
        """
        raise NotImplementedError
[docs] def unload(self): """ Unloads the model from memory. This helps reduce memory requirements while training other models. """ self._unload() self.resolver_configurations = {} self.ready = False
    @abstractmethod
    def _fit(self, clean, entity_map):
        """Fits the entity resolver model

        Args:
            clean (bool): If ``True``, deletes and recreates the index from scratch instead of
                updating the existing index with synonyms in the mapping.json.
            entity_map (json): json data loaded from the `mapping.json` file for the entity type
        """
        raise NotImplementedError

    @staticmethod
    def _get_model_hash(entities_data):
        """Returns a hash representing the inputs into the model

        Args:
            entities_data (List[dict]): The entity objects in the KB used to fit this model

        Returns:
            str: The hash
        """
        strings = sorted([json.dumps(ent_obj, sort_keys=True) for ent_obj in entities_data])
        return Hasher(algorithm="sha256").hash_list(strings=[*strings, ])

    def _get_entity_map(self, force_reload=False):
        try:
            return self._resource_loader.get_entity_map(self.type, force_reload=force_reload)
        except Exception as e:
            msg = f"Unable to load entity mapping data for " \
                  f"entity type: {self.type} in app_path: {self.app_path}"
            raise Exception(msg) from e

    @staticmethod
    def _format_entity_map(entities_data):
        """
        Args:
            entities_data (List[dict]): A list of dictionary objects each consisting of a 'cname'
                (canonical name string), 'whitelist' (a list of zero or more synonyms) and 'id'
                (a unique identifier for the set of cname and whitelist)

        Returns:
            entities_data (List[dict]): A reformatted entities_data list

        Raises:
            ValueError: if any object is missing both a cname and a whitelist
        """
        all_ids = set()
        for i, ent_object in enumerate(entities_data):
            _id = ent_object.get("id")
            cname = ent_object.get("cname")
            whitelist = list(dict.fromkeys(ent_object.get("whitelist", [])))
            if cname is None and len(whitelist) == 0:
                msg = f"Found no canonical name field 'cname' while processing KB objects. " \
                      f"The observed KB entity object is: {ent_object}"
                raise ValueError(msg)
            elif cname is None and len(whitelist):
                cname = whitelist[0]
                whitelist = whitelist[1:]
            if _id in all_ids:
                msg = f"Found a duplicate id {_id} while formatting data for entity resolution. "
                _id = uuid.uuid4()
                msg += f"Replacing it with a new id: {_id}"
                logger.warning(msg)
            if not _id:
                _id = uuid.uuid4()
                msg = f"Found an entry in entity_map without a corresponding id. " \
                      f"Creating a random new id ({_id}) for this object."
                logger.warning(msg)
            _id = str(_id)
            all_ids.update([_id])
            entities_data[i] = {"id": _id, "cname": cname, "whitelist": whitelist}
        return entities_data

    def _process_entities(
        self,
        entities,
        normalizer=None,
        augment_lower_case=False,
        augment_title_case=False,
        augment_normalized=False,
        normalize_aliases=False
    ):
        """
        Loads in the mapping.json file and stores the synonym mappings in an item_map and a
        synonym_map

        Args:
            entities (list[dict]): List of dictionaries with keys `id`, `cname` and `whitelist`
            normalizer (callable): The normalizer to use, if provided, used to normalize synonyms
            augment_lower_case (bool): Whether to extend the synonyms list with their lower-cased
                values
            augment_title_case (bool): Whether to extend the synonyms list with their title-cased
                values
            augment_normalized (bool): Whether to extend the synonyms list with their normalized
                values, using the provided normalizer
        """
        do_mutate_strings = any([augment_lower_case, augment_title_case, augment_normalized])
        if do_mutate_strings:
            msg = "Adding additional form of the whitelist and cnames to list of possible synonyms"
            logger.info(msg)

        item_map = {}
        syn_map = {}
        seen_ids = []
        for item in entities:
            item_id = item.get("id")
            cname = item["cname"]
            if cname in item_map:
                msg = "Canonical name %s specified in %s entity map multiple times"
                logger.debug(msg, cname, self.type)
            if item_id and item_id in seen_ids:
                msg = "Id %s specified in %s entity map multiple times"
                raise ValueError(msg.format(item_id, self.type))
            seen_ids.append(item_id)

            aliases = [cname] + item.pop("whitelist", [])
            if do_mutate_strings:
                new_aliases = []
                if augment_lower_case:
                    new_aliases.extend([string.lower() for string in aliases])
                if augment_title_case:
                    new_aliases.extend([string.title() for string in aliases])
                if augment_normalized and normalizer:
                    new_aliases.extend([normalizer(string) for string in aliases])
                aliases = {*aliases, *new_aliases}
            if normalize_aliases and normalizer:
                aliases = [normalizer(alias) for alias in aliases]

            items_for_cname = item_map.get(cname, [])
            items_for_cname.append(item)
            item_map[cname] = items_for_cname
            for alias in aliases:
                if alias in syn_map:
                    msg = "Synonym %s specified in %s entity map multiple times"
                    logger.debug(msg, cname, self.type)
                cnames_for_syn = syn_map.get(alias, [])
                cnames_for_syn.append(cname)
                syn_map[alias] = list(set(cnames_for_syn))

        return {"items": item_map, "synonyms": syn_map}

    @abstractmethod
    def _predict(self, nbest_entities, allowed_cnames=None):
        """Predicts the resolved value(s) for the given entity using cosine similarity.

        Args:
            nbest_entities (tuple): List of one entity object found in an input query, or a list
                of n-best entity objects.
            allowed_cnames (set, optional): if provided, predictions will only include objects
                related to these canonical names

        Returns:
            (list): The resolved values for the provided entity.
        """
        raise NotImplementedError

    def _trim_and_sort_results(self, results, top_n):
        """
        Trims down the results generated by any ER class, finally populating at most top_n
        documents

        Args:
            results (list[dict]): Each element in this list is a result dictionary with keys such
                as `id` (optional), `cname`, `score` and any others
            top_n (int): Number of top documents required to be populated

        Returns:
            list[dict]: if trimmed, a list similar to `results` but with fewer elements, else,
                the `results` list as-is is returned
        """
        if not results:
            return []

        if not isinstance(top_n, int) or top_n <= 0:
            msg = f"The value of 'top_n' set to '{top_n}' during predictions in " \
                  f"{self.__class__.__name__}. This will result in an unsorted list of documents. "
            logger.info(msg)
            return results

        # Obtain the top scored result for each doc id (only if a score field exists in results)
        best_results = {}
        for result in results:
            if "score" not in result:
                return results
            # use cname as id if no `id` field exists in results
            _id = result["id"] if "id" in result else result["cname"]
            if _id not in best_results or result["score"] > best_results[_id]["score"]:
                best_results[_id] = result
        results = [*best_results.values()]

        # Obtain up to top_n docs and sort them as the final result
        n_scores = len(results)
        if n_scores < top_n and top_n != DEFAULT_TOP_N:
            # log only if a value other than the default value is specified
            msg = f"Retrieved only {len(results)} entity resolutions instead of asked " \
                  f"number {top_n} for entity type {self.type}"
            logger.info(msg)
        elif n_scores > top_n:
            # select the top_n by using argpartition as it is faster than sorting
            _sim_scores = np.asarray([val["score"] for val in results])
            _top_inds = _sim_scores.argpartition(n_scores - top_n)[-top_n:]
            results = [results[ind] for ind in _top_inds]  # trimmed list of top_n docs

        return sorted(results, key=lambda x: x["score"], reverse=True)

    def _dump(self, path):
        pass

    def _load(self, path, entity_map):
        pass

    def _unload(self):
        pass
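    # A minimal standalone sketch (with assumed values) of the argpartition-based selection used
    # in _trim_and_sort_results above: argpartition moves the top_n largest scores into the last
    # top_n positions without fully sorting the array.
    #
    #   >>> import numpy as np
    #   >>> scores = np.asarray([0.1, 0.9, 0.4, 0.7, 0.2])
    #   >>> top_n = 2
    #   >>> top_inds = scores.argpartition(len(scores) - top_n)[-top_n:]
    #   >>> sorted(scores[top_inds].tolist(), reverse=True)
    #   [0.9, 0.7]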
[docs]class ExactMatchEntityResolver(BaseEntityResolver): """ Resolver class based on exact matching """ def __init__(self, app_path, entity_type, **kwargs): """ Args: app_path (str): The application path. entity_type (str): The entity type associated with this entity resolver. resource_loader (ResourceLoader, Optional): A resource loader object for the resolver. config (dict): Configurations can be passed in through `model_settings` field `model_settings` (dict): Following keys are configurable: augment_lower_case (bool): to augment lowercased synonyms as whitelist augment_title_case (bool): to augment titlecased synonyms as whitelist augment_normalized (bool): to augment text normalized synonyms as whitelist """ super().__init__(app_path, entity_type, **kwargs) self.resolver_configurations = kwargs.get("config", {}).get("model_settings", {}) self.processed_entity_map = None @BaseEntityResolver.resolver_configurations.setter def resolver_configurations(self, model_settings): self._model_settings = model_settings or {} self._aug_lower_case = self._model_settings.get("augment_lower_case", False) self._aug_title_case = self._model_settings.get("augment_title_case", False) self._aug_normalized = self._model_settings.get("augment_normalized", False) self._normalize_aliases = True self._model_settings.update({ "augment_lower_case": self._aug_lower_case, "augment_title_case": self._aug_title_case, "augment_normalized": self._aug_normalized, "normalize_aliases": self._normalize_aliases, })
[docs] def get_processed_entity_map(self, entity_map): """ Processes the entity map into a format suitable for indexing and similarity searching Args: entity_map (Dict[str, Union[str, List]]): Entity map if passed in directly instead of loading from a file path Returns: processed_entity_map (Dict): A processed entity map better suited for indexing and querying """ return self._process_entities( entity_map.get("entities", []), normalizer=self._resource_loader.query_factory.normalize, augment_lower_case=self._aug_lower_case, augment_title_case=self._aug_title_case, augment_normalized=self._aug_normalized, normalize_aliases=self._normalize_aliases )
def _fit(self, clean, entity_map): self.processed_entity_map = self.get_processed_entity_map(entity_map) if clean: msg = f"clean=True ignored while fitting {self.__class__.__name__}" logger.info(msg) def _predict(self, nbest_entities, allowed_cnames=None): """Looks for exact name in the synonyms data """ entity = nbest_entities[0] # top_entity normed = self._resource_loader.query_factory.normalize(entity.text) try: cnames = self.processed_entity_map["synonyms"][normed] except (KeyError, TypeError): logger.warning( "Failed to resolve entity %r for type %r", entity.text, entity.type ) return [] if len(cnames) > 1: logger.info( "Multiple possible canonical names for %r entity for type %r", entity.text, entity.type, ) values = [] for cname in cnames: if allowed_cnames and cname not in allowed_cnames: continue for item in self.processed_entity_map["items"][cname]: item_value = copy.copy(item) item_value.pop("whitelist", None) values.append(item_value) return values def _load(self, path, entity_map): self.processed_entity_map = self.get_processed_entity_map(entity_map) def _unload(self): self.processed_entity_map = None
[docs] def load_deprecated(self): self.fit()
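# Example configuration for the exact match resolver above (an illustrative sketch; `app_path`
# and `entity_type` are assumed placeholders, and the config format follows the factory
# docstring earlier in this module):
#
#   >>> er_config = {
#   ...     "model_type": "resolver",
#   ...     "model_settings": {
#   ...         "resolver_type": "exact_match",
#   ...         "augment_lower_case": True,
#   ...     },
#   ... }
#   >>> resolver = EntityResolverFactory.create_resolver(app_path, entity_type, config=er_config)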
class ElasticsearchEntityResolver(BaseEntityResolver):
    """
    Resolver class based on Elasticsearch
    """

    # prefix for Elasticsearch indices used to store synonyms for entity resolution
    ES_SYNONYM_INDEX_PREFIX = "synonym"
    """The prefix of the ES index."""

    def __init__(self, app_path, entity_type, **kwargs):
        """
        Args:
            app_path (str): The application path.
            entity_type (str): The entity type associated with this entity resolver.
            resource_loader (ResourceLoader, Optional): A resource loader object for the resolver.
            es_host (str): The Elasticsearch host server
            es_client (Elasticsearch): an Elasticsearch client
            config (dict): Configurations can be passed in through the `model_settings` field
                `model_settings` (dict): Following keys are configurable:
                    phonetic_match_types (List): a list of phonetic match types that are passed
                        to Elasticsearch. Currently supports only using the "double_metaphone"
                        string in the list.
        """
        super().__init__(app_path, entity_type, **kwargs)
        self.resolver_configurations = kwargs.get("config", {}).get("model_settings", {})

        self._es_host = kwargs.get("es_host")
        self._es_config = {"client": kwargs.get("es_client"), "pid": os.getpid()}
        self._app_namespace = get_app_namespace(self.app_path)

    @BaseEntityResolver.resolver_configurations.setter
    def resolver_configurations(self, model_settings):
        self._model_settings = model_settings or {}
        self._use_double_metaphone = "double_metaphone" in (
            self._model_settings.get("phonetic_match_types", [])
        )

    @property
    def _es_index_name(self):
        return f"{ElasticsearchEntityResolver.ES_SYNONYM_INDEX_PREFIX}_{self.type}"

    @property
    def _es_client(self):
        # Lazily connect to Elasticsearch. Make sure each subprocess gets its own connection
        if self._es_config["client"] is None or self._es_config["pid"] != os.getpid():
            self._es_config = {"pid": os.getpid(), "client": create_es_client()}
        return self._es_config["client"]
[docs] @staticmethod def ingest_synonym( app_namespace, index_name, index_type=INDEX_TYPE_SYNONYM, field_name=None, data=None, es_host=None, es_client=None, use_double_metaphone=False, ): """Loads synonym documents from the mapping.json data into the specified index. If an index with the specified name doesn't exist, a new index with that name will be created. Args: app_namespace (str): The namespace of the app. Used to prevent collisions between the indices of this app and those of other apps. index_name (str): The name of the new index to be created. index_type (str): specify whether to import to synonym index or knowledge base object index. INDEX_TYPE_SYNONYM is the default which indicates the synonyms to be imported to synonym index, while INDEX_TYPE_KB indicates that the synonyms should be imported into existing knowledge base index. field_name (str): specify name of the knowledge base field that the synonym list corresponds to when index_type is INDEX_TYPE_SYNONYM. data (list): A list of documents to be loaded into the index. es_host (str): The Elasticsearch host server. es_client (Elasticsearch): The Elasticsearch client. use_double_metaphone (bool): Whether to use the phonetic mapping or not. """ data = data or [] def _action_generator(docs): for doc in docs: action = {} # id if doc.get("id"): action["_id"] = doc["id"] else: # generate hash from canonical name as ID action["_id"] = hashlib.sha256( doc.get("cname").encode("utf-8") ).hexdigest() # synonym whitelist whitelist = doc["whitelist"] syn_list = [] syn_list.append({"name": doc["cname"]}) for syn in whitelist: syn_list.append({"name": syn}) # If index type is INDEX_TYPE_KB we import the synonym into knowledge base object # index by updating the knowledge base object with additional synonym whitelist # field. Otherwise, by default we import to synonym index in ES. if index_type == INDEX_TYPE_KB and field_name: syn_field = field_name + "$whitelist" action["_op_type"] = "update" action["doc"] = {syn_field: syn_list} else: action.update(doc) action["whitelist"] = syn_list yield action mapping = ( PHONETIC_ES_SYNONYM_MAPPING if use_double_metaphone else DEFAULT_ES_SYNONYM_MAPPING ) es_client = es_client or create_es_client(es_host) mapping = resolve_es_config_for_version(mapping, es_client) load_index( app_namespace, index_name, _action_generator(data), len(data), mapping, DOC_TYPE, es_host, es_client, )
def _fit(self, clean, entity_map): """Loads an entity mapping file to Elasticsearch for text relevance based entity resolution. In addition, the synonyms in entity mapping are imported to knowledge base indexes if the corresponding knowledge base object index and field name are specified for the entity type. The synonym info is then used by Question Answerer for text relevance matches. """ try: if clean: delete_index( self._app_namespace, self._es_index_name, self._es_host, self._es_client ) except ValueError as e: # when `clean = True` but no index to delete logger.error(e) entities = entity_map.get("entities", []) # create synonym index and import synonyms logger.info("Importing synonym data to synonym index '%s'", self._es_index_name) self.ingest_synonym( app_namespace=self._app_namespace, index_name=self._es_index_name, data=entities, es_host=self._es_host, es_client=self._es_client, use_double_metaphone=self._use_double_metaphone, ) # It's supported to specify the KB object type and field name that the NLP entity type # corresponds to in the mapping.json file. In this case the synonym whitelist is also # imported to KB object index and the synonym info will be used when using Question Answerer # for text relevance matches. kb_index = entity_map.get("kb_index_name") kb_field = entity_map.get("kb_field_name") # if KB index and field name is specified then also import synonyms into KB object index. if kb_index and kb_field: # validate the KB index and field are valid. # TODO: this validation can probably be in some other places like resource loader. if not does_index_exist( self._app_namespace, kb_index, self._es_host, self._es_client ): raise ValueError( "Cannot import synonym data to knowledge base. The knowledge base " "index name '{}' is not valid.".format(kb_index) ) if kb_field not in get_field_names( self._app_namespace, kb_index, self._es_host, self._es_client ): raise ValueError( "Cannot import synonym data to knowledge base. The knowledge base " "field name '{}' is not valid.".format(kb_field) ) if entities and not entities[0].get("id"): raise ValueError( "Knowledge base index and field cannot be specified for entities " "without ID." ) logger.info("Importing synonym data to knowledge base index '%s'", kb_index) ElasticsearchEntityResolver.ingest_synonym( app_namespace=self._app_namespace, index_name=kb_index, index_type="kb", field_name=kb_field, data=entities, es_host=self._es_host, es_client=self._es_client, use_double_metaphone=self._use_double_metaphone, ) def _predict(self, nbest_entities, allowed_cnames=None): """Predicts the resolved value(s) for the given entity using the loaded entity map or the trained entity resolution model. Args: nbest_entities (tuple): List of one entity object found in an input query, or a list \ of n-best entity objects. Returns: (list): The resolved values for the provided entity. """ if allowed_cnames: msg = f"Cannot set 'allowed_cnames' param for {self.__class__.__name__}." 
raise NotImplementedError(msg) top_entity = nbest_entities[0] weight_factors = [1 - float(i) / len(nbest_entities) for i in range(len(nbest_entities))] def _construct_match_query(entity, weight=1): return [ { "match": { "cname.normalized_keyword": { "query": entity.text, "boost": 10 * weight, } } }, {"match": {"cname.raw": {"query": entity.text, "boost": 10 * weight}}}, { "match": { "cname.char_ngram": {"query": entity.text, "boost": weight} } }, ] def _construct_nbest_match_query(entity, weight=1): return [ { "match": { "cname.normalized_keyword": { "query": entity.text, "boost": weight, } } } ] def _construct_phonetic_match_query(entity, weight=1): return [ { "match": { "cname.double_metaphone": { "query": entity.text, "boost": 2 * weight, } } } ] def _construct_whitelist_query(entity, weight=1, use_phons=False): query = { "nested": { "path": "whitelist", "score_mode": "max", "query": { "bool": { "should": [ { "match": { "whitelist.name.normalized_keyword": { "query": entity.text, "boost": 10 * weight, } } }, { "match": { "whitelist.name": { "query": entity.text, "boost": weight, } } }, { "match": { "whitelist.name.char_ngram": { "query": entity.text, "boost": weight, } } }, ] } }, "inner_hits": {}, } } if use_phons: query["nested"]["query"]["bool"]["should"].append( { "match": { "whitelist.double_metaphone": { "query": entity.text, "boost": 3 * weight, } } } ) return query text_relevance_query = { "query": { "function_score": { "query": {"bool": {"should": []}}, "field_value_factor": { "field": "sort_factor", "modifier": "log1p", "factor": 10, "missing": 0, }, "boost_mode": "sum", "score_mode": "sum", } } } match_query = [] top_transcript = True for e, weight in zip(nbest_entities, weight_factors): if top_transcript: match_query.extend(_construct_match_query(e, weight)) top_transcript = False else: match_query.extend(_construct_nbest_match_query(e, weight)) if self._use_double_metaphone: match_query.extend(_construct_phonetic_match_query(e, weight)) text_relevance_query["query"]["function_score"]["query"]["bool"][ "should" ].append({"bool": {"should": match_query}}) whitelist_query = _construct_whitelist_query( top_entity, use_phons=self._use_double_metaphone ) text_relevance_query["query"]["function_score"]["query"]["bool"][ "should" ].append(whitelist_query) try: index = get_scoped_index_name(self._app_namespace, self._es_index_name) response = self._es_client.search(index=index, body=text_relevance_query) except _getattr("elasticsearch", "ConnectionError") as ex: logger.error( "Unable to connect to Elasticsearch: %s details: %s", ex.error, ex.info ) raise ElasticsearchConnectionError(es_host=self._es_client.transport.hosts) from ex except _getattr("elasticsearch", "TransportError") as ex: logger.error( "Unexpected error occurred when sending requests to Elasticsearch: %s " "Status code: %s details: %s", ex.error, ex.status_code, ex.info, ) raise EntityResolverError( "Unexpected error occurred when sending requests to " "Elasticsearch: {} Status code: {} details: " "{}".format(ex.error, ex.status_code, ex.info) ) from ex except _getattr("elasticsearch", "ElasticsearchException") as ex: raise EntityResolverError from ex else: hits = response["hits"]["hits"] results = [] for hit in hits: if self._use_double_metaphone and len(nbest_entities) > 1: if hit["_score"] < 0.5 * len(nbest_entities): continue top_synonym = None synonym_hits = hit["inner_hits"]["whitelist"]["hits"]["hits"] if synonym_hits: top_synonym = synonym_hits[0]["_source"]["name"] result = { "cname": 
hit["_source"]["cname"], "score": hit["_score"], "top_synonym": top_synonym, } if hit["_source"].get("id"): result["id"] = hit["_source"].get("id") if hit["_source"].get("sort_factor"): result["sort_factor"] = hit["_source"].get("sort_factor") results.append(result) return results def _load(self, path, entity_map): del path try: scoped_index_name = get_scoped_index_name( self._app_namespace, self._es_index_name ) if not self._es_client.indices.exists(index=scoped_index_name): self.fit(entity_map=entity_map) except _getattr("elasticsearch", "ConnectionError") as e: logger.error( "Unable to connect to Elasticsearch: %s details: %s", e.error, e.info ) raise ElasticsearchConnectionError(es_host=self._es_client.transport.hosts) from e except _getattr("elasticsearch", "TransportError") as e: logger.error( "Unexpected error occurred when sending requests to Elasticsearch: %s " "Status code: %s details: %s", e.error, e.status_code, e.info, ) raise EntityResolverError from e except _getattr("elasticsearch", "ElasticsearchException") as e: raise EntityResolverError from e
[docs] def load_deprecated(self): try: scoped_index_name = get_scoped_index_name( self._app_namespace, self._es_index_name ) if not self._es_client.indices.exists(index=scoped_index_name): self.fit() except _getattr("elasticsearch", "ConnectionError") as e: logger.error( "Unable to connect to Elasticsearch: %s details: %s", e.error, e.info ) raise ElasticsearchConnectionError(es_host=self._es_client.transport.hosts) from e except _getattr("elasticsearch", "TransportError") as e: logger.error( "Unexpected error occurred when sending requests to Elasticsearch: %s " "Status code: %s details: %s", e.error, e.status_code, e.info, ) raise EntityResolverError from e except _getattr("elasticsearch", "ElasticsearchException") as e: raise EntityResolverError from e
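# Example configuration for the Elasticsearch resolver above (an illustrative sketch; `app_path`,
# `entity_type` and the host string are assumed placeholders):
#
#   >>> er_config = {
#   ...     "model_type": "resolver",
#   ...     "model_settings": {
#   ...         "resolver_type": "text_relevance",
#   ...         "phonetic_match_types": ["double_metaphone"],
#   ...     },
#   ... }
#   >>> resolver = EntityResolverFactory.create_resolver(
#   ...     app_path, entity_type, config=er_config, es_host="localhost:9200"
#   ... )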
[docs]class TfIdfSparseCosSimEntityResolver(BaseEntityResolver): # pylint: disable=too-many-instance-attributes """ a tf-idf based entity resolver using sparse matrices. ref: scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.TfidfVectorizer.html """ def __init__(self, app_path, entity_type, **kwargs): """ Args: app_path (str): The application path. entity_type (str): The entity type associated with this entity resolver. resource_loader (ResourceLoader, Optional): A resource loader object for the resolver. config (dict): Configurations can be passed in through `model_settings` field `model_settings`: augment_lower_case: to augment lowercased synonyms as whitelist augment_title_case: to augment titlecased synonyms as whitelist augment_normalized: to augment text normalized synonyms as whitelist augment_max_synonyms_embeddings: to augment pooled synonyms whose embedding is max-pool of all whitelist's (including above alterations) encodings """ super().__init__(app_path, entity_type, **kwargs) self.resolver_configurations = kwargs.get("config", {}).get("model_settings", {}) self.processed_entity_map = None self._analyzer = self._char_ngrams_plus_words_analyzer self._unique_synonyms = [] self._syn_tfidf_matrix = None self._vectorizer = None @BaseEntityResolver.resolver_configurations.setter def resolver_configurations(self, model_settings): self._model_settings = model_settings or {} self._aug_lower_case = self._model_settings.get("augment_lower_case", True) self._aug_title_case = self._model_settings.get("augment_title_case", False) self._aug_normalized = self._model_settings.get("augment_normalized", False) self._aug_max_syn_embs = self._model_settings.get("augment_max_synonyms_embeddings", True) self._normalize_aliases = False self.ngram_length = 5 # max number of character ngrams to consider; 3 for elasticsearch self._model_settings.update({ "augment_lower_case": self._aug_lower_case, "augment_title_case": self._aug_title_case, "augment_normalized": self._aug_normalized, "augment_max_synonyms_embeddings": self._aug_max_syn_embs, "normalize_aliases": self._normalize_aliases, "ngram_length": self.ngram_length, })
[docs] def get_processed_entity_map(self, entity_map): """ Processes the entity map into a format suitable for indexing and similarity searching Args: entity_map (Dict[str, Union[str, List]]): Entity map if passed in directly instead of loading from a file path Returns: processed_entity_map (Dict): A processed entity map better suited for indexing and querying """ return self._process_entities( entity_map.get("entities", []), normalizer=self._resource_loader.query_factory.normalize, augment_lower_case=self._aug_lower_case, augment_title_case=self._aug_title_case, augment_normalized=self._aug_normalized, normalize_aliases=self._normalize_aliases )
def _fit(self, clean, entity_map): self.processed_entity_map = self.get_processed_entity_map(entity_map) if clean: msg = f"clean=True ignored while fitting {self.__class__.__name__}" logger.info(msg) self._vectorizer = TfidfVectorizer(analyzer=self._analyzer, lowercase=False) # obtain sparse matrix synonyms = {v: k for k, v in dict(enumerate(set(self.processed_entity_map["synonyms"]))).items()} synonyms_embs = self._vectorizer.fit_transform([*synonyms.keys()]) # encode artificial synonyms if required if self._aug_max_syn_embs: # obtain cnames to synonyms mapping synonym2cnames = self.processed_entity_map["synonyms"] cname2synonyms = {} for syn, cnames in synonym2cnames.items(): for cname in cnames: items = cname2synonyms.get(cname, []) items.append(syn) cname2synonyms[cname] = items pooled_cnames, pooled_cnames_encodings = [], [] # assert pooled synonyms for cname, syns in cname2synonyms.items(): syns = list(set(syns)) if len(syns) == 1: continue pooled_cname = f"{cname} - SYNONYMS AVERAGE" # update synonyms map 'cause such synonyms don't actually exist in mapping.json file pooled_cname_aliases = synonym2cnames.get(pooled_cname, []) pooled_cname_aliases.append(cname) synonym2cnames[pooled_cname] = pooled_cname_aliases # check if needs to be encoded if pooled_cname in synonyms: continue # if required, obtain pooled encoding and update collections pooled_encoding = scipy.sparse.csr_matrix( np.max([synonyms_embs[synonyms[syn]].toarray() for syn in syns], axis=0) ) pooled_cnames.append(pooled_cname) pooled_cnames_encodings.append(pooled_encoding) if pooled_cnames_encodings: pooled_cnames_encodings = scipy.sparse.vstack(pooled_cnames_encodings) if pooled_cnames: synonyms_embs = ( pooled_cnames_encodings if not synonyms else scipy.sparse.vstack( [synonyms_embs, pooled_cnames_encodings]) ) synonyms.update( OrderedDict(zip( pooled_cnames, np.arange(len(synonyms), len(synonyms) + len(pooled_cnames))) ) ) # returns a sparse matrix self._unique_synonyms = [*synonyms.keys()] self._syn_tfidf_matrix = synonyms_embs def _predict(self, nbest_entities, allowed_cnames=None): # encode input entity top_entity = nbest_entities[0] # top_entity try: scored_items = self.find_similarity(top_entity.text, _no_sort=True) values = [] for synonym, score in scored_items: cnames = self.processed_entity_map["synonyms"][synonym] for cname in cnames: if allowed_cnames and cname not in allowed_cnames: continue for item in self.processed_entity_map["items"][cname]: item_value = copy.copy(item) item_value.pop("whitelist", None) item_value.update({"score": score}) item_value.update({"top_synonym": synonym}) values.append(item_value) except KeyError as e: msg = f"Failed to resolve entity {top_entity.text} for type {top_entity.type}; set " \ f"'clean=True' for computing TF-IDF of newly added items in mappings.json" logger.error(str(e)) logger.error(msg) return [] except TypeError as f: msg = f"Failed to resolve entity {top_entity.text} for type {top_entity.type}" logger.error(str(f)) logger.error(msg) return [] return values def _dump(self, path): resolver_state = { "unique_synonyms": self._unique_synonyms, # caching unique syns for finding similarity "syn_tfidf_matrix": self._syn_tfidf_matrix, # caching sparse vectors of synonyms "vectorizer": self._vectorizer, # caching vectorizer } os.makedirs(os.path.dirname(path), exist_ok=True) with open(path, "wb") as fp: pickle.dump(resolver_state, fp) def _load(self, path, entity_map): self.processed_entity_map = self.get_processed_entity_map(entity_map) with open(path, "rb") as fp: 
            resolver_state = pickle.load(fp)

        self._unique_synonyms = resolver_state["unique_synonyms"]
        self._syn_tfidf_matrix = resolver_state["syn_tfidf_matrix"]
        self._vectorizer = resolver_state["vectorizer"]

    def _unload(self):
        self.processed_entity_map = None
        self._unique_synonyms = []
        self._syn_tfidf_matrix = None
        self._vectorizer = None

    def _char_ngrams_plus_words_analyzer(self, string):
        """
        Analyzer that accounts for character ngrams as well as individual words in the input
        """
        # get char ngrams
        results = self._char_ngrams_analyzer(string)
        # add individual words
        words = re.split(r'[\s{}]+'.format(re.escape(punctuation)), string.strip())
        results.extend(words)
        return results

    def _char_ngrams_analyzer(self, string):
        """
        Analyzer that only accounts for character ngrams from size 1 to self.ngram_length
        """
        string = string.strip()
        if len(string) == 1:
            return [string]

        results = []
        # give importance to starting and ending characters of a word
        string = f" {string} "
        for n in range(self.ngram_length + 1):
            results.extend([''.join(gram) for gram in zip(*[string[i:] for i in range(n)])])
        results = list(set(results))
        results.remove(' ')
        # adding lowercased single characters might add more noise
        results = [r for r in results if not (len(r) == 1 and r.islower())]
        # returns an empty list for an empty string
        return results
[docs] def find_similarity( self, src_texts, top_n=DEFAULT_TOP_N, scores_normalizer=None, _return_as_dict=False, _no_sort=False ): """Computes sparse cosine similarity Args: src_texts (Union[str, list]): string or list of strings to obtain matching scores for. top_n (int, optional): maximum number of results to populate. if None, equals length of self._syn_tfidf_matrix scores_normalizer (str, optional): normalizer type to normalize scores. Allowed values are: "min_max_scaler", "standard_scaler" _return_as_dict (bool, optional): if the results should be returned as a dictionary of target_text name as keys and scores as corresponding values _no_sort (bool, optional): If True, results are returned without sorting. This is helpful at times when you wish to do additional wrapper operations on top of raw results and would like to save computational time without sorting. Returns: Union[dict, list[tuple]]: if _return_as_dict, returns a dictionary of tgt_texts and their scores, else a list of sorted synonym names paired with their similarity scores (descending order) """ is_single = False if isinstance(src_texts, str): is_single = True src_texts = [src_texts] top_n = self._syn_tfidf_matrix.shape[0] if not top_n else top_n results = [] for src_text in src_texts: src_text_vector = self._vectorizer.transform([src_text]) similarity_scores = self._syn_tfidf_matrix.dot(src_text_vector.T).toarray().reshape(-1) # Rounding sometimes helps to bring correct answers on to the # top score as other non-correct resolutions similarity_scores = np.around(similarity_scores, decimals=4) if scores_normalizer: if scores_normalizer == "min_max_scaler": _min = np.min(similarity_scores) _max = np.max(similarity_scores) denominator = (_max - _min) if (_max - _min) != 0 else 1.0 similarity_scores = (similarity_scores - _min) / denominator elif scores_normalizer == "standard_scaler": _mean = np.mean(similarity_scores) _std = np.std(similarity_scores) denominator = _std if _std else 1.0 similarity_scores = (similarity_scores - _mean) / denominator else: msg = f"Allowed values for `scores_normalizer` are only " \ f"{['min_max_scaler', 'standard_scaler']}. Continuing without " \ f"normalizing similarity scores." logger.error(msg) if _return_as_dict: results.append(dict(zip(self._unique_synonyms, similarity_scores))) else: if not _no_sort: # sort results in descending scores n_scores = len(similarity_scores) if n_scores > top_n: top_inds = similarity_scores.argpartition(n_scores - top_n)[-top_n:] result = sorted( [(self._unique_synonyms[ii], similarity_scores[ii]) for ii in top_inds], key=lambda x: x[1], reverse=True) else: result = sorted(zip(self._unique_synonyms, similarity_scores), key=lambda x: x[1], reverse=True) results.append(result) else: result = list(zip(self._unique_synonyms, similarity_scores)) results.append(result) if is_single: return results[0] return results
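    # Example use of find_similarity (an illustrative sketch; assumes a fitted TF-IDF resolver
    # and placeholder query strings):
    #
    #   >>> resolver.find_similarity("seaweed salad", top_n=5)
    #   >>> resolver.find_similarity(["seaweed salad", "green tea"], scores_normalizer="min_max_scaler")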
[docs] def load_deprecated(self): self.fit()
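# Example configuration for the TF-IDF resolver above (an illustrative sketch; `app_path` and
# `entity_type` are assumed placeholders):
#
#   >>> er_config = {
#   ...     "model_type": "resolver",
#   ...     "model_settings": {
#   ...         "resolver_type": "tfidf_cosine_similarity",
#   ...         "augment_lower_case": True,
#   ...         "augment_max_synonyms_embeddings": True,
#   ...     },
#   ... }
#   >>> resolver = EntityResolverFactory.create_resolver(app_path, entity_type, config=er_config)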
[docs]class EmbedderCosSimEntityResolver(BaseEntityResolver): """ Resolver class for embedder models that create dense embeddings """ def __init__(self, app_path, entity_type, **kwargs): """ Args: app_path (str): The application path. entity_type (str): The entity type associated with this entity resolver. resource_loader (ResourceLoader, Optional): A resource loader object for the resolver. config (dict): Configurations can be passed in through `model_settings` field `model_settings`: embedder_type: the type of embedder picked from embedder_models.py class (eg. 'bert', 'glove', etc. ) augment_lower_case: to augment lowercased synonyms as whitelist augment_title_case: to augment titlecased synonyms as whitelist augment_normalized: to augment text normalized synonyms as whitelist augment_average_synonyms_embeddings: to augment pooled synonyms whose embedding is average of all whitelist's (including above alterations) encodings embedder_cache_path (str): A path where the embedder cache can be stored. If it is not specified, an embedder will be instantiated using the app_path information. If specified, it will be used to dump the embeddings cache. """ super().__init__(app_path, entity_type, **kwargs) self.resolver_configurations = kwargs.get("config", {}).get("model_settings", {}) self.processed_entity_map = None self._embedder_model = None @BaseEntityResolver.resolver_configurations.setter def resolver_configurations(self, model_settings): self._model_settings = model_settings or {} self._aug_lower_case = self._model_settings.get("augment_lower_case", False) self._aug_title_case = self._model_settings.get("augment_title_case", False) self._aug_normalized = self._model_settings.get("augment_normalized", False) self._aug_avg_syn_embs = self._model_settings.get( "augment_average_synonyms_embeddings", True) self._normalize_aliases = False self._model_settings.update({ "augment_lower_case": self._aug_lower_case, "augment_title_case": self._aug_title_case, "augment_normalized": self._aug_normalized, "normalize_aliases": self._normalize_aliases, "augment_max_synonyms_embeddings": self._aug_avg_syn_embs, })
[docs] def get_processed_entity_map(self, entity_map): """ Processes the entity map into a format suitable for indexing and similarity searching Args: entity_map (Dict[str, Union[str, List]]): Entity map if passed in directly instead of loading from a file path Returns: processed_entity_map (Dict): A processed entity map better suited for indexing and querying """ return self._process_entities( entity_map.get("entities", []), normalizer=self._resource_loader.query_factory.normalize, augment_lower_case=self._aug_lower_case, augment_title_case=self._aug_title_case, augment_normalized=self._aug_normalized, normalize_aliases=self._normalize_aliases )
def _fit(self, clean, entity_map): self.processed_entity_map = self.get_processed_entity_map(entity_map) self._embedder_model = create_embedder_model( app_path=self.app_path, config=self.resolver_configurations ) if clean: msg = f"clean=True ignored while fitting {self.__class__.__name__}" logger.info(msg) # load embeddings from cache if exists, encode any other synonyms if required self._embedder_model.get_encodings([*self.processed_entity_map["synonyms"].keys()]) # encode artificial synonyms if required if self._aug_avg_syn_embs: # obtain cnames to synonyms mapping cname2synonyms = {} for syn, cnames in self.processed_entity_map["synonyms"].items(): for cname in cnames: cname2synonyms[cname] = cname2synonyms.get(cname, []) + [syn] # create and add superficial data for cname, syns in cname2synonyms.items(): syns = list(set(syns)) if len(syns) == 1: continue pooled_cname = f"{cname} - SYNONYMS AVERAGE" # update synonyms map 'cause such synonyms don't actually exist in mapping.json file if pooled_cname not in self.processed_entity_map["synonyms"]: self.processed_entity_map["synonyms"][pooled_cname] = [cname] # obtain encoding and update cache # TODO: asumption that embedding cache has __getitem__ can be addressed if pooled_cname in self._embedder_model.cache: continue pooled_encoding = np.mean(self._embedder_model.get_encodings(syns), axis=0) self._embedder_model.add_to_cache({pooled_cname: pooled_encoding}) # useful for validation while loading self._model_settings["embedder_model_id"] = self._embedder_model.model_id # snippet for backwards compatibility # even if the .dump() method of resolver isn't called explicitly, the embeddings need to be # cached for fast inference of resolver; however, with the introduction of dump() and # load() methods, this temporary persisting is not necessary and must be removed in future # versions self._embedder_model.dump_cache() def _predict(self, nbest_entities, allowed_cnames=None): """Predicts the resolved value(s) for the given entity using cosine similarity. """ # encode input entity top_entity = nbest_entities[0] # top_entity allowed_syns = None if allowed_cnames: syn2cnames = self.processed_entity_map["synonyms"] allowed_syns = [syn for syn, cnames in syn2cnames.items() if any([cname in allowed_cnames for cname in cnames])] try: scored_items = self._embedder_model.find_similarity( top_entity.text, tgt_texts=allowed_syns, _no_sort=True) values = [] for synonym, score in scored_items: cnames = self.processed_entity_map["synonyms"][synonym] for cname in cnames: if allowed_cnames and cname not in allowed_cnames: continue for item in self.processed_entity_map["items"][cname]: item_value = copy.copy(item) item_value.pop("whitelist", None) item_value.update({"score": score}) item_value.update({"top_synonym": synonym}) values.append(item_value) except KeyError as e: msg = f"Failed to resolve entity {top_entity.text} for type {top_entity.type}; set " \ f"'clean=True' for computing embeddings of newly added items in mappings.json" logger.error(str(e)) logger.error(msg) return [] except TypeError as f: msg = f"Failed to resolve entity {top_entity.text} for type {top_entity.type}" logger.error(str(f)) logger.error(msg) return [] except RuntimeError as r: # happens when the input is an empty string and an embedder models fails to embed it msg = f"Failed to resolve entity {top_entity.text} for type {top_entity.type}" if "mat1 and mat2 shapes cannot be multiplied" in str(r): msg += ". This can happen if the input passed to embedder is an empty string!" 
            logger.error(str(r))
            logger.error(msg)
            raise RuntimeError(msg) from r

        return values

    def _dump(self, path):
        # kept due to backwards compatibility in _fit(), must be removed in future versions
        self._embedder_model.clear_cache()  # delete the temp cache as the .dump() method is now used

        head, ext = os.path.splitext(path)
        embedder_cache_path = head + ".embedder_cache" + ext
        self._embedder_model.dump_cache(cache_path=embedder_cache_path)
        self._model_settings["embedder_cache_path"] = embedder_cache_path

    def _load(self, path, entity_map):
        self.processed_entity_map = self.get_processed_entity_map(entity_map)
        self._embedder_model = create_embedder_model(
            app_path=self.app_path, config=self.resolver_configurations
        )

        # validate model id and load cache
        if self.resolver_configurations["embedder_model_id"] != self._embedder_model.model_id:
            msg = f"Unable to resolve the embedder model configurations. Found mismatched " \
                  f"configurations between configs in the loaded pickle file and the configs " \
                  f"specified while instantiating {self.__class__.__name__}. Delete the related " \
                  f"model files and re-fit the resolver. Note that embedder models are not " \
                  f"pickled due to their large disk sizes and are only loaded from input configs."
            raise ValueError(msg)
        self._embedder_model.load_cache(
            cache_path=self.resolver_configurations["embedder_cache_path"]
        )

    def _unload(self):
        self.processed_entity_map = None
        self._embedder_model = None

    def _predict_batch(self, nbest_entities_list, batch_size):
        # encode input entity
        top_entity_list = [i[0].text for i in nbest_entities_list]  # top_entity

        try:
            # w/o batch, [ nsyms x 768*4 ] x [ 1 x 768*4 ] --> [ nsyms x 1 ]
            # w/ batch,  [ nsyms x 768*4 ] x [ k x 768*4 ] --> [ nsyms x k ]
            scored_items_list = []
            for st_idx in trange(0, len(top_entity_list), batch_size, disable=False):
                batch = top_entity_list[st_idx:st_idx + batch_size]
                result = self._embedder_model.find_similarity(batch, _no_sort=True)
                scored_items_list.extend(result)

            values_list = []
            for scored_items in scored_items_list:
                values = []
                for synonym, score in scored_items:
                    cnames = self.processed_entity_map["synonyms"][synonym]
                    for cname in cnames:
                        for item in self.processed_entity_map["items"][cname]:
                            item_value = copy.copy(item)
                            item_value.pop("whitelist", None)
                            item_value.update({"score": score})
                            item_value.update({"top_synonym": synonym})
                            values.append(item_value)
                values_list.append(values)
        except (KeyError, TypeError) as e:
            logger.error(e)
            return None

        return values_list
[docs] def predict_batch(self, entity_list, top_n: int = DEFAULT_TOP_N, batch_size: int = 8): if self._no_trainable_canonical_entity_map: return [[] for _ in entity_list] nbest_entities_list = [] results_list = [] for entity in entity_list: if isinstance(entity, (list, tuple)): top_entity = entity[0] nbest_entities = tuple(entity) else: top_entity = entity nbest_entities = tuple([entity]) nbest_entities_list.append(nbest_entities) if self._is_system_entity: # system entities are already resolved results_list.append(top_entity.value) if self._is_system_entity: return results_list results_list = self._predict_batch(nbest_entities_list, batch_size) return [self._trim_and_sort_results(results, top_n) for results in results_list]
[docs] def load_deprecated(self): self.fit()
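# Example configuration for the embedder-based resolver above (an illustrative sketch;
# `app_path`, `entity_type` and the 'glove' embedder choice are assumed placeholders):
#
#   >>> er_config = {
#   ...     "model_type": "resolver",
#   ...     "model_settings": {
#   ...         "resolver_type": "embedder_cosine_similarity",
#   ...         "embedder_type": "glove",
#   ...         "augment_average_synonyms_embeddings": True,
#   ...     },
#   ... }
#   >>> resolver = EntityResolverFactory.create_resolver(app_path, entity_type, config=er_config)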
class SentenceBertCosSimEntityResolver(EmbedderCosSimEntityResolver):
    """
    Resolver class for BERT models based on the sentence-transformers library
    https://github.com/UKPLab/sentence-transformers
    """

    def __init__(self, app_path, entity_type, **kwargs):
        """
        This wrapper class allows creation of a BERT base embedder class
        (currently based on sentence-transformers).

        Specifically, this wrapper updates er_config in kwargs with
            - any default settings if unavailable in input
            - cache path

        Args:
            app_path (str): App's path to cache embeddings
            er_config (dict): Configurations can be passed in through the `model_settings` field
                `model_settings`:
                    embedder_type: the type of embedder picked from embedder_models.py class
                        (eg. 'bert', 'glove', etc.)
                    pretrained_name_or_abspath: the pretrained model for 'bert' embedder
                    bert_output_type: if the output is a sentence mean pool or CLS output
                    quantize_model: if the model needs to be quantized for faster inference time
                        but at a possibly reduced accuracy
                    concat_last_n_layers: if some of the last layers of a BERT model are to be
                        concatenated for better accuracies
                    normalize_token_embs: if the obtained sub-token level encodings are to be
                        normalized
        """
        # default configs useful for reusing model's encodings through a cache path
        defaults = {
            "embedder_type": "bert",
            "pretrained_name_or_abspath": "sentence-transformers/all-mpnet-base-v2",
            "bert_output_type": "mean",
            "quantize_model": True,
            "concat_last_n_layers": 1,
            "normalize_token_embs": False,
        }
        # update er_configs in the kwargs with the defaults if any of the default keys are missing
        kwargs.update({
            "config": {
                **kwargs.get("config", {}),
                "model_settings": {
                    **defaults,
                    **kwargs.get("config", {}).get("model_settings", {}),
                },
            }
        })
        super().__init__(app_path, entity_type, **kwargs)
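# Example configuration for the sentence-BERT resolver above (an illustrative sketch; `app_path`
# and `entity_type` are assumed placeholders; the model name shown is the documented default):
#
#   >>> er_config = {
#   ...     "model_type": "resolver",
#   ...     "model_settings": {
#   ...         "resolver_type": "sbert_cosine_similarity",
#   ...         "pretrained_name_or_abspath": "sentence-transformers/all-mpnet-base-v2",
#   ...     },
#   ... }
#   >>> resolver = EntityResolverFactory.create_resolver(app_path, entity_type, config=er_config)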
[docs]class EntityResolver: """ Class for backwards compatibility deprecated usage >>> entity_resolver = EntityResolver( app_path, resource_loader, entity_type ) new usage >>> entity_resolver = EntityResolverFactory.create_resolver( app_path, entity_type ) # or ... >>> entity_resolver = EntityResolverFactory.create_resolver( app_path, entity_type, resource_loader=resource_loader ) """ def __new__(cls, app_path, resource_loader, entity_type, **kwargs): msg = "Entity Resolver should now be loaded using EntityResolverFactory. " \ "See https://www.mindmeld.com/docs/userguide/entity_resolver.html for more details." warnings.warn(msg, DeprecationWarning) return EntityResolverFactory.create_resolver( app_path, entity_type, resource_loader=resource_loader, **kwargs )
ENTITY_RESOLVER_MODEL_MAPPINGS = { "exact_match": ExactMatchEntityResolver, "text_relevance": ElasticsearchEntityResolver, # TODO: In the newly added resolvers, to support # (1) using all provided entities (i.e all nbest_entities) like elastic search # (2) using kb_index_name and kb_field_name as used by Elasticsearch resolver "sbert_cosine_similarity": SentenceBertCosSimEntityResolver, "tfidf_cosine_similarity": TfIdfSparseCosSimEntityResolver, "embedder_cosine_similarity": EmbedderCosSimEntityResolver, }