# -*- coding: utf-8 -*-
#
# Copyright (c) 2015 Cisco Systems, Inc. and others. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the role classifier component of the MindMeld natural language processor.
"""
import logging
import pickle
import joblib
from ._config import get_classifier_config
from .classifier import Classifier, ClassifierConfig, ClassifierLoadError
from ..constants import DEFAULT_TRAIN_SET_REGEX
from ..core import Query
from ..models import CLASS_LABEL_TYPE, ENTITY_EXAMPLE_TYPE, create_model, load_model
from ..resource_loader import ProcessedQueryList
# Module-level logger, named after this module so its records can be filtered per module.
logger = logging.getLogger(name=__name__)
class RoleClassifier(Classifier):
    """A role classifier is used to determine the target role for entities in a given query. It is
    trained using all the labeled queries for a particular intent. The labels are the role names
    associated with each entity within each query.

    Attributes:
        domain (str): The domain that this role classifier belongs to
        intent (str): The intent that this role classifier belongs to
        entity_type (str): The entity type that this role classifier is for
        roles (set): A set containing the roles which can be classified
    """

    CLF_TYPE = "role"

    def __init__(self, resource_loader, domain, intent, entity_type):
        """Initializes a role classifier

        Args:
            resource_loader (ResourceLoader): An object which can load resources for the classifier
            domain (str): The domain that this role classifier belongs to
            intent (str): The intent that this role classifier belongs to
            entity_type (str): The entity type that this role classifier is for
        """
        super().__init__(resource_loader)
        self.domain = domain
        self.intent = intent
        self.entity_type = entity_type
        self.roles = set()

    # pylint: disable=arguments-differ
    def _get_model_config(self, **kwargs):
        """Builds the machine learning model configuration for this role classifier.

        The app's role classifier config for this domain / intent / entity type is loaded and
        handed, together with ``kwargs``, to the base class's ``_get_model_config``. The example
        type and label type are always forced to entity examples with class labels.

        Args:
            **kwargs: Configuration overrides forwarded to the base class.

        Returns:
            ModelConfig: The model configuration for this role classifier
        """
        kwargs["example_type"] = ENTITY_EXAMPLE_TYPE
        kwargs["label_type"] = CLASS_LABEL_TYPE
        loaded_config = get_classifier_config(
            self.CLF_TYPE,
            self._resource_loader.app_path,
            domain=self.domain,
            intent=self.intent,
            entity=self.entity_type,
        )
        return super()._get_model_config(loaded_config, **kwargs)

    def fit(self,
            queries=None,
            label_set=None,
            incremental_timestamp=None,
            load_cached=True, **kwargs):
        """Trains a statistical model for role classification using the provided training examples.

        Args:
            queries (list of ProcessedQuery): The labeled queries to use as training data
            label_set (list, optional): A label set to load. If not specified, the default
                training set will be loaded.
            incremental_timestamp (str, optional): The timestamp folder to cache models in
            load_cached (bool, optional): When ``incremental_timestamp`` is given and a model
                with the same hash is already cached, whether to load that cached model.
                Defaults to True.
            **kwargs: Model configuration overrides passed to ``_get_model_config``.

        Returns:
            bool: False when a cached model exists but ``load_cached`` is False (nothing is
                fitted or loaded); True otherwise.
        """
        logger.info(
            "Fitting role classifier: domain=%r, intent=%r, entity_type=%r",
            self.domain,
            self.intent,
            self.entity_type,
        )

        # create model with given params
        model_config = self._get_model_config(**kwargs)
        label_set = label_set or model_config.train_label_set or DEFAULT_TRAIN_SET_REGEX
        queries = self._resolve_queries(queries, label_set)

        new_hash = self._get_model_hash(model_config, queries)
        cached_model_path = self._resource_loader.hash_to_model_path.get(new_hash)

        # Examples and labels are ProcessedQueryList.ListIterator wrappers around flat lists,
        # or empty tuples when no entity of this type carries a role.
        examples, labels = self._get_examples_and_labels(queries)

        # Rebuild (rather than extend) the role set so that refitting on different data does
        # not keep roles left over from a previous fit.
        self.roles = set(labels)

        if incremental_timestamp and cached_model_path:
            logger.info("No need to fit. Previous model is cached.")
            if load_cached:
                # load() sets self.ready = True
                self.load(cached_model_path)
                return True
            return False

        if self.roles:
            model = create_model(model_config)
            model.initialize_resources(self._resource_loader, examples, labels)
            model.fit(examples, labels)
            self._model = model
            self.config = ClassifierConfig.from_model_config(self._model.config)
        else:
            # No roles in the training data: discard any model from a previous fit so that
            # predict() does not silently use it.
            self._model = None

        self.hash = new_hash
        self.ready = True
        self.dirty = True
        return True

    def dump(self, model_path, incremental_model_path=None):
        """Persists the trained role classification model to disk.

        Args:
            model_path (str): The model path.
            incremental_model_path (str, Optional): The timestamped folder where the cached \
                models are stored.
        """
        logger.info(
            "Saving role classifier: domain=%r, intent=%r, entity=%r",
            self.domain,
            self.intent,
            self.entity_type,
        )
        super().dump(model_path, incremental_model_path)

    def _dump(self, path):
        """Pickles this classifier's role set to its classifier-resources path.

        Args:
            path (str): The model path from which the resources file location is derived.
        """
        rc_data = {"roles": self.roles}
        # Context manager guarantees the file is flushed and closed even if pickling fails.
        with open(self._get_classifier_resources_save_path(path), "wb") as resources_file:
            pickle.dump(rc_data, resources_file)

    def unload(self):
        """Drops the in-memory model and role set and marks the classifier as not ready."""
        self._model = None
        self.roles = set()
        self.ready = False

    def load(self, model_path):
        """Loads the trained role classification model from disk.

        Args:
            model_path (str): The location on disk where the model is stored

        Raises:
            ClassifierLoadError: If the loaded model has no ``mindmeld_version`` attribute,
                i.e. it was trained with an incompatible version of MindMeld.
        """
        logger.info(
            "Loading role classifier: domain=%r, intent=%r, entity_type=%r",
            self.domain,
            self.intent,
            self.entity_type,
        )

        # underlying model specific load
        self._model = load_model(model_path)

        # classifier specific load
        try:
            with open(
                self._get_classifier_resources_save_path(model_path), "rb"
            ) as resources_file:
                # NOTE: unpickling can execute arbitrary code; only load model files you trust.
                rc_data = pickle.load(resources_file)
        except FileNotFoundError:  # backwards compatibility for previous version's saved models
            rc_data = joblib.load(model_path)
        self.roles = rc_data["roles"]

        # validate and register resources
        if self._model is not None:
            if not hasattr(self._model, "mindmeld_version"):
                msg = (
                    "Your trained models are incompatible with this version of MindMeld. "
                    "Please run a clean build to retrain models"
                )
                raise ClassifierLoadError(msg)

            try:
                self._model.config.to_dict()
            except AttributeError:
                # Loaded model config is incompatible with app config.
                self._model.config.resolve_config(self._get_model_config())

            self._register_model_resources()
            self.config = ClassifierConfig.from_model_config(self._model.config)

        self.hash = self._load_hash(model_path)
        self.ready = True
        self.dirty = False

    def _register_model_resources(self):
        """Registers the app's gazetteers and text preparation pipeline with the model."""
        # Keyword arguments are evaluated left to right: gazetteers first, then the pipeline.
        self._model.register_resources(
            gazetteers=self._resource_loader.get_gazetteers(),
            text_preparation_pipeline=self._resource_loader.get_text_preparation_pipeline(),
        )

    def predict(
        self, query, entities, entity_index
    ):  # pylint: disable=arguments-differ
        """Predicts a role for the given entity using the trained role classification model.

        Args:
            query (Query): The input query
            entities (list): The entities in the query
            entity_index (int): The index of the entity whose role should be classified

        Returns:
            str: The predicted role for the provided entity, or None (after logging an error)
                if no model has been fitted or loaded
        """
        if not self._model:
            logger.error("You must fit or load the model before running predict")
            return None
        if len(self.roles) == 1:
            # Only one possible role: no need to consult the model.
            return list(self.roles)[0]
        if not isinstance(query, Query):
            query = self._resource_loader.query_factory.create_query(query)
        self._register_model_resources()
        return self._model.predict([(query, entities, entity_index)])[0]

    def predict_proba(
        self, query, entities, entity_index
    ):  # pylint: disable=arguments-differ
        """Runs prediction on a given entity and generates multiple role hypotheses and
        associated probabilities using the trained role classification model.

        Args:
            query (Query): The input query
            entities (list): The entities in the query
            entity_index (int): The index of the entity whose role should be classified

        Returns:
            list: a list of tuples of the form (str, float) grouping roles and their
                probabilities, sorted by descending probability; None (after logging an
                error) if no model has been fitted or loaded
        """
        if not self._model:
            logger.error("You must fit or load the model before running predict")
            return None
        if len(self.roles) == 1:
            return [(list(self.roles)[0], 1.0)]
        if not isinstance(query, Query):
            query = self._resource_loader.query_factory.create_query(query)
        self._register_model_resources()
        predict_proba_result = self._model.predict_proba(
            [(query, entities, entity_index)]
        )
        class_proba_tuples = list(predict_proba_result[0][1].items())
        return sorted(class_proba_tuples, key=lambda x: x[1], reverse=True)

    # pylint: disable=arguments-differ
    def _get_queries_from_label_set(self, label_set=DEFAULT_TRAIN_SET_REGEX):
        """Returns the flattened labeled queries for this domain/intent from ``label_set``."""
        return self._resource_loader.get_flattened_label_set(
            domain=self.domain,
            intent=self.intent,
            label_set=label_set
        )

    def _get_examples_and_labels(self, queries):
        """Collects training examples and role labels for this classifier's entity type.

        Only entities of ``self.entity_type`` that carry a (truthy) role are used; entities of
        this type without a role are skipped.

        Args:
            queries (ProcessedQueryList): The labeled queries to train on.

        Returns:
            tuple: ``(examples, labels)`` as ``ProcessedQueryList.ListIterator`` objects, where
                each example is a ``(query, entities, entity_index)`` tuple and each label is the
                entity's role. Both are empty tuples when no matching entity has a role.
        """
        # build list of examples -- entities of this role classifier's type
        examples = []
        labels = []
        for processed_query in queries.processed_queries():
            for idx, entity in enumerate(processed_query.entities):
                if entity.entity.type == self.entity_type and entity.entity.role:
                    examples.append((processed_query.query, processed_query.entities, idx))
                    labels.append(entity.entity.role)

        if not labels:
            # No roles
            return (), ()

        return (ProcessedQueryList.ListIterator(examples),
                ProcessedQueryList.ListIterator(labels))

    def _get_examples_and_labels_hash(self, queries):
        """Hashes the sorted raw training queries, prefixed with this classifier's identity."""
        hashable_queries = (
            [self.domain + "###" + self.intent + "###" + self.entity_type + "###"]
            + sorted(queries.raw_queries())
        )
        return self._resource_loader.hash_list(hashable_queries)

    def inspect(self, query, gold_label=None, dynamic_resource=None):
        """Not implemented for role classifiers: logs a warning and returns None."""
        del gold_label
        del dynamic_resource
        del query
        logger.warning("method not implemented")