# -*- coding: utf-8 -*-
#
# Copyright (c) 2015 Cisco Systems, Inc. and others. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the entity recognizer component of the MindMeld natural language processor.
"""
import logging
import pickle
import joblib
from ._config import get_classifier_config
from .classifier import Classifier, ClassifierConfig, ClassifierLoadError
from ..constants import DEFAULT_TRAIN_SET_REGEX
from ..core import Entity, Query
from ..models import ENTITIES_LABEL_TYPE, QUERY_EXAMPLE_TYPE, create_model, load_model
logger = logging.getLogger(__name__)
class EntityRecognizer(Classifier):
    """An entity recognizer which is used to identify the entities for a given query. It is
    trained using all the labeled queries for a particular intent. The labels are the entity
    annotations for each query.

    Attributes:
        domain (str): The domain that this entity recognizer belongs to.
        intent (str): The intent that this entity recognizer belongs to.
        entity_types (set): A set containing the entity types which can be recognized.
    """

    CLF_TYPE = "entity"
    """The classifier type."""
def __init__(self, resource_loader, domain, intent):
"""Initializes an entity recognizer
Args:
resource_loader (ResourceLoader): An object which can load resources for the classifier
domain (str): The domain that this entity recognizer belongs to
intent (str): The intent that this entity recognizer belongs to
"""
super().__init__(resource_loader)
self.domain = domain
self.intent = intent
self.entity_types = set()
# TODO: Deprecate the var self._model_config as the configs are already dumped by models
self._model_config = None
def _get_model_config(self, **kwargs): # pylint: disable=arguments-differ
"""Gets a machine learning model configuration
Returns:
ModelConfig: The model configuration corresponding to the provided config name
"""
kwargs["example_type"] = QUERY_EXAMPLE_TYPE
kwargs["label_type"] = ENTITIES_LABEL_TYPE
loaded_config = get_classifier_config(
self.CLF_TYPE,
self._resource_loader.app_path,
domain=self.domain,
intent=self.intent,
)
return super()._get_model_config(loaded_config, **kwargs)
[docs] def get_entity_types(self, queries=None, label_set=None, **kwargs):
if not label_set:
label_set = self._get_model_config(**kwargs).train_label_set
label_set = label_set if label_set else DEFAULT_TRAIN_SET_REGEX
# Load labeled data
queries = self._resolve_queries(queries, label_set)
queries, labels = self._get_examples_and_labels(queries)
# Build entity types set
entity_types = set()
for label in labels:
for entity in label:
entity_types.add(entity.entity.type)
return entity_types
    def fit(self,
            queries=None,
            label_set=None,
            incremental_timestamp=None,
            load_cached=True,
            **kwargs):
        """Trains the entity recognizer on the labeled queries for this domain/intent.

        Args:
            queries (optional): Pre-loaded labeled queries; resolved from ``label_set``
                when omitted.
            label_set (str, optional): The label set to train on; falls back to the model
                config's train label set, then the default train-set regex.
            incremental_timestamp (str, optional): When set, a matching cached model may be
                reused instead of retraining.
            load_cached (bool, optional): Whether to actually load a matched cached model.
            **kwargs: Additional model configuration overrides.

        Returns:
            bool: True if a model was fit or loaded from cache; False when a cached model
            was matched but ``load_cached`` is False.
        """
        logger.info(
            "Fitting entity recognizer: domain=%r, intent=%r", self.domain, self.intent
        )

        # create model with given params
        self._model_config = self._get_model_config(**kwargs)
        label_set = label_set or self._model_config.train_label_set or DEFAULT_TRAIN_SET_REGEX
        queries = self._resolve_queries(queries, label_set)

        new_hash = self._get_model_hash(self._model_config, queries)
        cached_model_path = self._resource_loader.hash_to_model_path.get(new_hash)

        # After PR 356, entity.pkl file is not created when there are no entity types,
        # similar to not having domain.pkl or intent.pkl when there are less than 2 domains
        # or 2 intents respectively.
        # Before this PR, not doing this dump leads to `cached_model_path=None` in above line. After
        # this PR, this will be set to `cached_model_path=<>.pkl` path and the self.load() takes
        # care of loading a NoneType model. Had it been `cached_model_path=None` like previously,
        # the following code skips the `load_cached` check and directly attempts to create a new
        # model. This is not an issue in domain and intent classifiers as the .fit() method is not
        # called when there are less than 2 domains/intents.

        # Load labeled data
        examples, labels = self._get_examples_and_labels(queries)

        if examples:
            # Build entity types set
            self.entity_types = {entity.entity.type for label in labels for entity in label}

        if incremental_timestamp and cached_model_path:
            logger.info("No need to fit. Previous model is cached.")
            if load_cached:
                self.load(cached_model_path)
                return True
            return False

        # Only build a model when there is something to recognize; otherwise the
        # recognizer stays model-less but is still marked ready below.
        if self.entity_types:
            model = create_model(self._model_config)
            model.initialize_resources(self._resource_loader, examples, labels)
            model.fit(examples, labels)
            self._model = model
            self.config = ClassifierConfig.from_model_config(self._model.config)

        self.hash = new_hash
        self.ready = True
        self.dirty = True
        return True
[docs] def dump(self, model_path, incremental_model_path=None):
"""Save the model.
Args:
model_path (str): The model path.
incremental_model_path (str, Optional): The timestamped folder where the cached \
models are stored.
"""
logger.info(
"Saving entity classifier: domain=%r, intent=%r", self.domain, self.intent
)
super().dump(model_path, incremental_model_path)
def _dump(self, path):
er_data = {
"entity_types": self.entity_types,
"model_config": self._model_config,
}
if self._model:
er_data.update({
"w_ngram_freq": self._model.get_resource("w_ngram_freq"),
"c_ngram_freq": self._model.get_resource("c_ngram_freq"),
})
pickle.dump(er_data, open(self._get_classifier_resources_save_path(path), "wb"))
[docs] def unload(self):
logger.info(
"Unloading entity recognizer: domain=%r, intent=%r", self.domain, self.intent
)
self.entity_types = None
self._model_config = None
self._model = None
self.ready = False
[docs] def load(self, model_path):
"""Loads the trained entity recognition model from disk.
Args:
model_path (str): The location on disk where the model is stored.
"""
logger.info(
"Loading entity recognizer: domain=%r, intent=%r", self.domain, self.intent
)
# underlying model specific load
self._model = load_model(model_path)
# classifier specific load
try:
er_data = pickle.load(open(self._get_classifier_resources_save_path(model_path), "rb"))
except FileNotFoundError: # backwards compatability for previous version's saved models
er_data = joblib.load(model_path)
self.entity_types = er_data["entity_types"]
self._model_config = er_data["model_config"]
# validate and register resources
if self._model is not None:
if not hasattr(self._model, "mindmeld_version"):
msg = (
"Your trained models are incompatible with this version of MindMeld. "
"Please run a clean build to retrain models"
)
raise ClassifierLoadError(msg)
try:
self._model.config.to_dict()
except AttributeError:
# Loaded model config is incompatible with app config.
self._model.config.resolve_config(self._get_model_config())
gazetteers = self._resource_loader.get_gazetteers()
text_preparation_pipeline = self._resource_loader.get_text_preparation_pipeline()
sys_types = set(
(t for t in self.entity_types if Entity.is_system_entity(t))
)
w_ngram_freq = er_data.get("w_ngram_freq")
c_ngram_freq = er_data.get("c_ngram_freq")
self._model.register_resources(
gazetteers=gazetteers,
sys_types=sys_types,
w_ngram_freq=w_ngram_freq,
c_ngram_freq=c_ngram_freq,
text_preparation_pipeline=text_preparation_pipeline,
)
self.config = ClassifierConfig.from_model_config(self._model.config)
self.hash = self._load_hash(model_path)
self.ready = True
self.dirty = False
[docs] def predict(self, query, time_zone=None, timestamp=None, dynamic_resource=None):
"""Predicts entities for the given query using the trained recognition model.
Args:
query (Query, str): The input query.
time_zone (str, optional): The name of an IANA time zone, such as
'America/Los_Angeles', or 'Asia/Kolkata'
See the [tz database](https://www.iana.org/time-zones) for more information.
timestamp (long, optional): A unix time stamp for the request (in seconds).
dynamic_resource (dict, optional): A dynamic resource to aid NLP inference.
Returns:
(str): The predicted class label.
"""
prediction = (
super().predict(
query,
time_zone=time_zone,
timestamp=timestamp,
dynamic_resource=dynamic_resource,
)
or ()
)
return tuple(sorted(prediction, key=lambda e: e.span.start))
[docs] def predict_proba(
self, query, time_zone=None, timestamp=None, dynamic_resource=None
):
"""Runs prediction on a given query and generates multiple entity tagging hypotheses with
their associated probabilities using the trained entity recognition model
Args:
query (Query, str): The input query.
time_zone (str, optional): The name of an IANA time zone, such as
'America/Los_Angeles', or 'Asia/Kolkata'
See the [tz database](https://www.iana.org/time-zones) for more information.
timestamp (long, optional): A unix time stamp for the request (in seconds).
dynamic_resource (optional): Dynamic resource, unused.
Returns:
(list): A list of tuples of the form (Entity list, float) grouping potential entity \
tagging hypotheses and their probabilities.
"""
del dynamic_resource
if not self._model:
logger.error("You must fit or load the model before running predict_proba")
return []
if not isinstance(query, Query):
query = self._resource_loader.query_factory.create_query(
query, time_zone=time_zone, timestamp=timestamp
)
predict_proba_result = self._model.predict_proba([query])
return predict_proba_result
def _get_queries_from_label_set(self, label_set=DEFAULT_TRAIN_SET_REGEX):
return self._resource_loader.get_flattened_label_set(
domain=self.domain,
intent=self.intent,
label_set=label_set
)
def _get_examples_and_labels(self, queries):
return (queries.queries(), queries.entities())
def _get_examples_and_labels_hash(self, queries):
hashable_queries = (
[self.domain + "###" + self.intent + "###entity###"]
+ sorted(list(queries.raw_queries()))
)
return self._resource_loader.hash_list(hashable_queries)
[docs] def inspect(self, query, gold_label=None, dynamic_resource=None):
del query
del gold_label
del dynamic_resource
logger.warning("method not implemented")