Source code for mindmeld.components.domain_classifier

# -*- coding: utf-8 -*-
#
# Copyright (c) 2015 Cisco Systems, Inc. and others.  All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This module contains the domain classifier component of the MindMeld natural language processor.
"""
import logging

from ..constants import DEFAULT_TRAIN_SET_REGEX
from ..markup import mark_down
from ..models import CLASS_LABEL_TYPE, QUERY_EXAMPLE_TYPE
from ._config import get_classifier_config
from .classifier import Classifier

logger = logging.getLogger(__name__)


class DomainClassifier(Classifier):
    """Classifier that determines the target domain for a given query.

    The model is trained on all of the labeled queries across every intent of
    every domain in an application. Each training query is labeled with the
    name of the domain it belongs to.
    """

    CLF_TYPE = "domain"
    """The classifier type."""

    def _get_model_config(self, **kwargs):
        """Builds the machine learning model configuration for this classifier.

        The example and label types are always forced to the query example type
        and the class label type before delegating to the parent class.

        Returns:
            ModelConfig: The model configuration corresponding to the provided config name
        """
        kwargs.update(example_type=QUERY_EXAMPLE_TYPE, label_type=CLASS_LABEL_TYPE)
        app_level_config = get_classifier_config(
            self.CLF_TYPE, self._resource_loader.app_path
        )
        return super()._get_model_config(app_level_config, **kwargs)

    def fit(self, *args, **kwargs):  # pylint: disable=signature-differs
        """Trains the domain classification model using the provided training queries.

        All arguments are forwarded unchanged to the parent class.

        Args:
            model_type (str): The type of machine learning model to use. If omitted, the
                default model type will be used.
            features (dict): Features to extract from each example instance to form the
                feature vector used for model training. If omitted, the default feature set
                for the model type will be used.
            params_grid (dict): The grid of hyper-parameters to search, for finding the
                optimal hyper-parameter settings for the model. If omitted, the default
                hyper-parameter search grid will be used.
            cv (None, optional): Cross-validation settings
            queries (list of ProcessedQuery): The labeled queries to use as training data

        Returns:
            Whatever the parent class's ``fit`` returns.
        """
        logger.info("Fitting domain classifier")
        outcome = super().fit(*args, **kwargs)
        return outcome

    def dump(self, *args, **kwargs):  # pylint: disable=signature-differs
        """Persists the trained domain classification model to disk.

        Args:
            model_path (str): The location on disk where the model should be stored
        """
        logger.info("Saving domain classifier")
        super().dump(*args, **kwargs)

    def unload(self):
        """Logs the unload request and delegates to the parent class's ``unload``."""
        logger.info("Unloading domain classifier")
        super().unload()

    def load(self, *args, **kwargs):
        """Loads the trained domain classification model from disk.

        Args:
            model_path (str): The location on disk where the model is stored
        """
        logger.info("Loading domain classifier")
        super().load(*args, **kwargs)

    def inspect(self, query, domain=None, dynamic_resource=None):
        """Inspects the query.

        Args:
            query (Query): The query to be predicted.
            domain (str): The expected domain label for this query.
            dynamic_resource (dict, optional): A dynamic resource to aid NLP inference.

        Returns:
            (list of lists): 2D list that includes every feature, their value, weight and \
                probability.
        """
        inspection = self._model.inspect(
            example=query, gold_label=domain, dynamic_resource=dynamic_resource
        )
        return inspection

    def _get_queries_from_label_set(self, label_set=DEFAULT_TRAIN_SET_REGEX):
        """Fetches the flattened label set matching ``label_set`` from the resource loader."""
        loader = self._resource_loader
        return loader.get_flattened_label_set(label_set=label_set)

    def _get_examples_and_labels(self, queries):
        """Returns an ``(examples, labels)`` pair: the queries and their domains."""
        examples = queries.queries()
        labels = queries.domains()
        return (examples, labels)

    def _get_examples_and_labels_hash(self, queries):
        """Hashes the training data as a sorted list of domain/query strings.

        Each entry is the domain, the separator ``###`` and the marked-down raw
        query. The entries are sorted before being handed to the resource
        loader's ``hash_list``.
        """
        hash_entries = sorted(
            domain + "###" + mark_down(raw_query)
            for domain, raw_query in zip(queries.domains(), queries.raw_queries())
        )
        return self._resource_loader.hash_list(hash_entries)