Source code for mindmeld.components.entity_recognizer
# -*- coding: utf-8 -*-## Copyright (c) 2015 Cisco Systems, Inc. and others. All rights reserved.# Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at# http://www.apache.org/licenses/LICENSE-2.0# Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""This module contains the entity recognizer component of the MindMeld natural language processor."""importloggingimportpickleimportjoblibfrom._configimportget_classifier_configfrom.classifierimportClassifier,ClassifierConfig,ClassifierLoadErrorfrom..constantsimportDEFAULT_TRAIN_SET_REGEXfrom..coreimportEntity,Queryfrom..modelsimportENTITIES_LABEL_TYPE,QUERY_EXAMPLE_TYPE,create_model,load_modellogger=logging.getLogger(__name__)
class EntityRecognizer(Classifier):
    """An entity recognizer which is used to identify the entities for a given query.

    It is trained using all the labeled queries for a particular intent. The labels are
    the entity annotations for each query.

    Attributes:
        domain (str): The domain that this entity recognizer belongs to
        intent (str): The intent that this entity recognizer belongs to
        entity_types (set): A set containing the entity types which can be recognized
    """

    CLF_TYPE = "entity"
    """The classifier type."""

    def __init__(self, resource_loader, domain, intent):
        """Initializes an entity recognizer

        Args:
            resource_loader (ResourceLoader): An object which can load resources for the
                classifier
            domain (str): The domain that this entity recognizer belongs to
            intent (str): The intent that this entity recognizer belongs to
        """
        super().__init__(resource_loader)
        self.domain = domain
        self.intent = intent
        self.entity_types = set()
        # TODO: Deprecate the var self._model_config as the configs are already dumped by models
        self._model_config = None

    def _get_model_config(self, **kwargs):  # pylint: disable=arguments-differ
        """Gets a machine learning model configuration

        The ``example_type`` and ``label_type`` entries of ``kwargs`` are always
        overwritten with the query/entities constants before delegating to the parent.

        Returns:
            ModelConfig: The model configuration corresponding to the provided config name
        """
        kwargs["example_type"] = QUERY_EXAMPLE_TYPE
        kwargs["label_type"] = ENTITIES_LABEL_TYPE
        loaded_config = get_classifier_config(
            self.CLF_TYPE,
            self._resource_loader.app_path,
            domain=self.domain,
            intent=self.intent,
        )
        return super()._get_model_config(loaded_config, **kwargs)

    def fit(
        self,
        queries=None,
        label_set=None,
        incremental_timestamp=None,
        load_cached=True,
        **kwargs
    ):
        """Trains the entity recognition model for this domain/intent.

        Args:
            queries (optional): The queries to train on; handed to ``_resolve_queries``
                together with the resolved label set.
            label_set (optional): The label set to train on. Falls back to the model
                config's ``train_label_set`` and then to ``DEFAULT_TRAIN_SET_REGEX``.
            incremental_timestamp (optional): When truthy and a cached model exists for
                the computed hash, training is skipped.
            load_cached (bool, optional): Whether a cached model found during an
                incremental build should be loaded.
            **kwargs: Forwarded to ``_get_model_config``.

        Returns:
            bool: False when a cached model was found but ``load_cached`` is false (nothing
                is loaded in that case); True otherwise.
        """
        logger.info(
            "Fitting entity recognizer: domain=%r, intent=%r", self.domain, self.intent
        )

        # create model with given params
        self._model_config = self._get_model_config(**kwargs)

        label_set = label_set or self._model_config.train_label_set or DEFAULT_TRAIN_SET_REGEX
        queries = self._resolve_queries(queries, label_set)

        new_hash = self._get_model_hash(self._model_config, queries)
        cached_model_path = self._resource_loader.hash_to_model_path.get(new_hash)

        # After PR 356, entity.pkl file is not created when there are no entity types,
        # similar to not having domain.pkl or intent.pkl when there are less than 2 domains
        # or 2 intents respectively.
        # Before this PR, not doing this dump leads to `cached_model_path=None` in above line.
        # After this PR, this will be set to `cached_model_path=<>.pkl` path and the
        # self.load() takes care of loading a NoneType model. Had it been
        # `cached_model_path=None` like previously, the following code skips the
        # `load_cached` check and directly attempts to create a new model. This is not an
        # issue in domain and intent classifiers as the .fit() method is not called when
        # there are less than 2 domains/intents.

        # Load labeled data
        examples, labels = self._get_examples_and_labels(queries)

        # NOTE(review): the nesting below was reconstructed from a whitespace-flattened
        # listing; confirm that the cache check and training both belong under
        # `if examples:`.
        if examples:
            # Build entity types set
            self.entity_types = {entity.entity.type for label in labels for entity in label}

            if incremental_timestamp and cached_model_path:
                logger.info("No need to fit. Previous model is cached.")
                if load_cached:
                    self.load(cached_model_path)
                    return True
                return False

            # Only train when there is at least one entity type to recognize.
            if self.entity_types:
                model = create_model(self._model_config)
                model.initialize_resources(self._resource_loader, examples, labels)
                model.fit(examples, labels)
                self._model = model
                self.config = ClassifierConfig.from_model_config(self._model.config)

        self.hash = new_hash

        self.ready = True
        self.dirty = True
        return True

    def dump(self, model_path, incremental_model_path=None):
        """Save the model.

        Args:
            model_path (str): The model path.
            incremental_model_path (str, Optional): The timestamped folder where the cached
                models are stored.
        """
        logger.info(
            "Saving entity classifier: domain=%r, intent=%r", self.domain, self.intent
        )
        super().dump(model_path, incremental_model_path)

    def load(self, model_path):
        """Loads the trained entity recognition model from disk.

        Args:
            model_path (str): The location on disk where the model is stored.

        Raises:
            ClassifierLoadError: If the loaded model has no ``mindmeld_version`` attribute.
        """
        logger.info(
            "Loading entity recognizer: domain=%r, intent=%r", self.domain, self.intent
        )

        # underlying model specific load
        self._model = load_model(model_path)

        # classifier specific load
        # NOTE: pickle/joblib deserialization can execute arbitrary code; only load model
        # artifacts from trusted locations.
        try:
            # Context manager so the resources file handle is always closed.
            with open(self._get_classifier_resources_save_path(model_path), "rb") as fp:
                er_data = pickle.load(fp)
        except FileNotFoundError:
            # backwards compatibility for previous version's saved models
            er_data = joblib.load(model_path)

        self.entity_types = er_data["entity_types"]
        self._model_config = er_data["model_config"]

        # validate and register resources
        if self._model is not None:
            if not hasattr(self._model, "mindmeld_version"):
                msg = (
                    "Your trained models are incompatible with this version of MindMeld. "
                    "Please run a clean build to retrain models"
                )
                raise ClassifierLoadError(msg)

            try:
                self._model.config.to_dict()
            except AttributeError:
                # Loaded model config is incompatible with app config.
                self._model.config.resolve_config(self._get_model_config())

            gazetteers = self._resource_loader.get_gazetteers()
            text_preparation_pipeline = self._resource_loader.get_text_preparation_pipeline()
            sys_types = {t for t in self.entity_types if Entity.is_system_entity(t)}

            w_ngram_freq = er_data.get("w_ngram_freq")
            c_ngram_freq = er_data.get("c_ngram_freq")

            self._model.register_resources(
                gazetteers=gazetteers,
                sys_types=sys_types,
                w_ngram_freq=w_ngram_freq,
                c_ngram_freq=c_ngram_freq,
                text_preparation_pipeline=text_preparation_pipeline,
            )
            self.config = ClassifierConfig.from_model_config(self._model.config)

        self.hash = self._load_hash(model_path)

        self.ready = True
        self.dirty = False

    def predict(self, query, time_zone=None, timestamp=None, dynamic_resource=None):
        """Predicts entities for the given query using the trained recognition model.

        Args:
            query (Query, str): The input query.
            time_zone (str, optional): The name of an IANA time zone, such as
                'America/Los_Angeles', or 'Asia/Kolkata'
                See the [tz database](https://www.iana.org/time-zones) for more information.
            timestamp (long, optional): A unix time stamp for the request (in seconds).
            dynamic_resource (dict, optional): A dynamic resource to aid NLP inference.

        Returns:
            (tuple): The predicted entities, sorted by the start of their span. Empty when
                the parent classifier returns a falsy prediction.
        """
        prediction = (
            super().predict(
                query,
                time_zone=time_zone,
                timestamp=timestamp,
                dynamic_resource=dynamic_resource,
            )
            or ()
        )
        return tuple(sorted(prediction, key=lambda e: e.span.start))

    def predict_proba(self, query, time_zone=None, timestamp=None, dynamic_resource=None):
        """Runs prediction on a given query and generates multiple entity tagging
        hypotheses with their associated probabilities using the trained entity
        recognition model

        Args:
            query (Query, str): The input query.
            time_zone (str, optional): The name of an IANA time zone, such as
                'America/Los_Angeles', or 'Asia/Kolkata'
                See the [tz database](https://www.iana.org/time-zones) for more information.
            timestamp (long, optional): A unix time stamp for the request (in seconds).
            dynamic_resource (optional): Dynamic resource, unused.

        Returns:
            (list): A list of tuples of the form (Entity list, float) grouping potential
                entity tagging hypotheses and their probabilities. An empty list is
                returned (and an error logged) when no model has been fit or loaded.
        """
        del dynamic_resource

        if not self._model:
            logger.error("You must fit or load the model before running predict_proba")
            return []

        # Plain strings are turned into Query objects before being handed to the model.
        if not isinstance(query, Query):
            query = self._resource_loader.query_factory.create_query(
                query, time_zone=time_zone, timestamp=timestamp
            )
        predict_proba_result = self._model.predict_proba([query])
        return predict_proba_result