# -*- coding: utf-8 -*-
#
# Copyright (c) 2015 Cisco Systems, Inc. and others. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains base classes for models defined in the models subpackage."""
import copy
import json
import logging
import math
import os
import pickle
from abc import ABC, abstractmethod
from inspect import signature
from typing import Union, Type, Dict, Any, Tuple, List, Pattern, Set
import joblib
from sklearn.model_selection import (
GridSearchCV,
GroupKFold,
GroupShuffleSplit,
KFold,
ShuffleSplit,
StratifiedKFold,
StratifiedShuffleSplit,
)
from sklearn.preprocessing import LabelEncoder as SKLabelEncoder
from ._util import _is_module_available
from .evaluation import EntityModelEvaluation, StandardModelEvaluation
from .helpers import (
CHAR_NGRAM_FREQ_RSC,
ENABLE_STEMMING,
GAZETTEER_RSC,
WORD_NGRAM_FREQ_RSC,
QUERY_FREQ_RSC,
SYS_TYPES_RSC,
WORD_FREQ_RSC,
SENTIMENT_ANALYZER,
get_feature_extractor,
get_label_encoder,
ingest_dynamic_gazetteer,
)
from .nn_utils.helpers import EmbedderType
from .._version import get_mm_version
from ..core import ProcessedQuery, QueryEntity
from ..resource_loader import ResourceLoader, ProcessedQueryList as PQL
from ..text_preparation.text_preparation_pipeline import (
TextPreparationPipelineFactory,
TextPreparationPipeline
)
# for backwards compatability for sklearn models serialized and dumped in previous version
from .labels import LabelEncoder, EntityLabelEncoder # pylint: disable=unused-import
logger = logging.getLogger(__name__)
Examples = Union[PQL.QueryIterator, PQL.ListIterator]
Labels = Union[PQL.DomainIterator, PQL.IntentIterator, PQL.EntitiesIterator]
[docs]class ModelConfig:
"""A value object representing a model configuration.
Attributes:
model_type (str): The name of the model type. Will be used to find the
model class to instantiate
example_type (str): The type of the examples which will be passed into
`fit()` and `predict()`. Used to select feature extractors
label_type (str): The type of the labels which will be passed into
`fit()` and returned by `predict()`. Used to select the label encoder
model_settings (dict): Settings specific to the model type specified
params (dict): Params to pass to the underlying classifier
param_selection (dict): Configuration for param selection (using cross
validation)
{'type': 'shuffle',
'n': 3,
'k': 10,
'n_jobs': 2,
'scoring': '',
'grid': {}
}
features (dict): The keys are the names of feature extractors and the
values are either a kwargs dict which will be passed into the
feature extractor function, or a callable which will be used as to
extract features
train_label_set (regex pattern): The regex pattern for finding training
file names.
test_label_set (regex pattern): The regex pattern for finding testing
file names.
"""
__slots__ = [
"model_type",
"example_type",
"label_type",
"features",
"model_settings",
"params",
"param_selection",
"train_label_set",
"test_label_set",
]
def __init__(
self,
model_type: str = None,
example_type: str = None,
label_type: str = None,
features: Dict = None,
model_settings: Dict = None,
params: Dict = None,
param_selection: Dict = None,
train_label_set: Pattern[str] = None,
test_label_set: Pattern[str] = None,
):
for arg, val in {
"model_type": model_type,
"label_type": label_type,
}.items():
if val is None:
raise TypeError("__init__() missing required argument {!r}".format(arg))
self.model_type = model_type
self.example_type = example_type
self.label_type = label_type
self.features = features
self.model_settings = model_settings
self.params = params
self.param_selection = param_selection
self.train_label_set = train_label_set
self.test_label_set = test_label_set
def __repr__(self):
args_str = ", ".join(
"{}={!r}".format(key, getattr(self, key)) for key in self.__slots__
)
return "{}({})".format(self.__class__.__name__, args_str)
[docs] def to_dict(self) -> Dict:
"""Converts the model config object into a dict
Returns:
dict: A dict version of the config
"""
result = {}
for attr in self.__slots__:
result[attr] = getattr(self, attr)
return result
[docs] def to_json(self) -> str:
"""Converts the model config object to JSON
Returns:
str: JSON representation of the classifier
"""
return json.dumps(self.to_dict(), sort_keys=True)
[docs] def resolve_config(self, new_config: "ModelConfig"):
"""This method resolves any config incompatibility issues by
loading the latest settings from the app config to the current config
Args:
new_config (ModelConfig): The ModelConfig representing the app's latest config
"""
new_settings = ["train_label_set", "test_label_set"]
logger.warning(
"Loading missing properties %s from app " "configuration file", new_settings
)
for setting in new_settings:
setattr(self, setting, getattr(new_config, setting))
[docs] def get_ngram_lengths_and_thresholds(self, rname: str) -> Tuple:
"""
Returns the n-gram lengths and thresholds to extract to optimize resource collection
Args:
rname (string): Name of the resource
Returns:
(tuple): tuple containing:
* lengths (list of int): list of n-gram lengths to be extracted
* thresholds (list of int): thresholds to be applied to corresponding n-gram lengths
"""
lengths = thresholds = None
# if it's not the n-gram feature, we don't need length and threshold information
if rname == CHAR_NGRAM_FREQ_RSC:
feature_name = "char-ngrams"
elif rname == WORD_NGRAM_FREQ_RSC:
feature_name = "bag-of-words"
else:
return lengths, thresholds
# feature name varies based on whether it's for a classifier or tagger
if self.model_type == "text":
if feature_name in self.features:
lengths = self.features[feature_name]["lengths"]
thresholds = self.features[feature_name].get(
"thresholds", [1] * len(lengths)
)
elif self.model_type == "tagger":
feature_name = feature_name + "-seq"
if feature_name in self.features:
lengths = self.features[feature_name][
"ngram_lengths_to_start_positions"
].keys()
thresholds = self.features[feature_name].get(
"thresholds", [1] * len(lengths)
)
return lengths, thresholds
[docs] def required_resources(self) -> Set:
"""Returns the resources this model requires
Returns:
set: set of required resources for this model
"""
# get list of resources required by feature extractors
required_resources = set()
if self.features:
for name in self.features:
feature = get_feature_extractor(self.example_type, name)
required_resources.update(feature.__dict__.get("requirements", []))
return required_resources
[docs]class AbstractModel(ABC):
"""
A minimalistic abstract class upon which all models are based.
In order to maintain backwards compatability, the skeleton of this class is designed based on
all the access points of Classifier class and its sub-classes. In addition, it also introduces
the decoupled way of dumping/loading across different model types (meaning not all models are
dumped/loaded the same way). Furthermore, methods for validation are also introduced so as to
cater to the model specific config validations. Lastly, this skeleton also includes some
common properties that could be used across all model types.
"""
def __init__(self, config: ModelConfig):
self.config = config
self.mindmeld_version = get_mm_version()
self._resources = {}
self._validate_model_configs()
[docs] @abstractmethod
def initialize_resources(
self, resource_loader: ResourceLoader, examples: Examples = None, labels: Labels = None
):
raise NotImplementedError
[docs] @abstractmethod
def fit(self, examples: Examples, labels: Labels, params: Dict = None):
raise NotImplementedError
[docs] @abstractmethod
def predict(
self, examples: Examples, dynamic_resource: Dict = None
) -> Union[List[Any], List[List[Any]]]:
raise NotImplementedError
[docs] @abstractmethod
def predict_proba(
self, examples: Examples, dynamic_resource: Dict = None
) -> Union[List[Tuple[str, Dict[str, float]]], Tuple[Tuple[QueryEntity, float]]]:
raise NotImplementedError
[docs] @abstractmethod
def evaluate(
self, examples: Examples, labels: Labels
) -> Union[StandardModelEvaluation, EntityModelEvaluation]:
raise NotImplementedError
[docs] @classmethod
@abstractmethod
def load(cls, path: str) -> Type["AbstractModel"]:
raise NotImplementedError
[docs] @classmethod
def load_model_config(cls, path: str) -> ModelConfig:
"""
Dumps the model's configs. Raises a FileNotFoundError if no configs file is found.
For backwards compatability wherein TextModel was serialized and dumped, the textModel file
is loaded using joblib and then the config is obtained from its public variables.
Args:
path (str): The path where the model is dumped
"""
try:
model_configs_save_path = cls._get_model_config_save_path(path)
model_config = pickle.load(open(model_configs_save_path, "rb"))
except FileNotFoundError as e: # backwards compatability for sklearn-based model classes
metadata = joblib.load(path)
# metadata here can be a serialized model (eg. TextModel) or a dict (eg. TaggerModel)
if isinstance(metadata, dict):
# compatability with previously dumped EntityRecognizers and RoleClassifiers
try:
# sklearn TaggerModel used by EntityRecognizer
model_config = metadata["model_config"]
except KeyError:
try:
# sklearn TextModel used by RoleClassifier (w/ a non-NoneType "model")
model_config = metadata["model"].config
except AttributeError:
# sklearn TextModel used by RoleClassifier (w/ a NoneType "model")
# in latest version, nothing gets dumped at the dump
# path: 'path/to/dump/<entity_name>-role.pkl' if the self._model in
# RoleClassifier is None because of a check in Classifier.dump()
# in previous version, a dictionary '{'model': None, 'roles': set()}'
# is dumped at the path: 'path/to/dump/<entity_name>-role.pkl'
# although the self._model in RoleClassifier is None
msg = f"Model config data cold not be identified from existing dump at " \
f"path: {path}. Assuming that the dumped model is NoneType and " \
f"belongs to a role classifier"
raise FileNotFoundError(msg) from e
else:
# compatability with previously dumped DomainClassifiers and IntentClassifiers
# in this case, metadata = model which was serialized and dumped
model_config = metadata.config
return model_config
[docs] def dump(self, path: str):
"""
Dumps the model's configs and calls the child model's dump method
Args:
path (str): The path to dump the model to
"""
# every subclass of ABCModel has one .pkl dump file that contains a
# dictionary with at least the key 'model_config' whose value is a model configs dict
# this pickle file is sought in `load_model_config()` method
self._dump_model_config(path)
# call this subclass-implemeneted method to allow models to dump in their own style
# (eg. serialized dump for sklearn-based models, .bin files for pytorch-based models, etc.)
self._dump(path)
def _dump_model_config(self, path: str):
"""
Dumps the model's configs
"""
model_configs_save_path = self._get_model_config_save_path(path)
pickle.dump(self.config, open(model_configs_save_path, "wb"))
@abstractmethod
def _dump(self, path: str):
"""
Dumps the model and calls the underlying algo to dump its state.
Args:
path (str): The path to dump the model to
"""
pass
@staticmethod
def _get_model_config_save_path(path: str) -> str:
head, ext = os.path.splitext(path)
model_config_save_path = head + ".config" + ext
os.makedirs(os.path.dirname(model_config_save_path), exist_ok=True)
return model_config_save_path
[docs] def register_resources(self, **kwargs): # pylint: disable=no-self-use
# Resources for feature extractors are not required for deep neural models
del kwargs
pass
[docs] def get_resource(self, name) -> Any:
return self._resources.get(name)
@property
def text_preparation_pipeline(self) -> TextPreparationPipeline:
text_preparation_pipeline = self._resources.get("text_preparation_pipeline")
if not text_preparation_pipeline:
logger.error(
"The text_preparation_pipeline resource has not been registered "
"to the model. Using default text_preparation_pipeline."
)
return TextPreparationPipelineFactory.create_default_text_preparation_pipeline()
return text_preparation_pipeline
def _validate_model_configs(self):
pass
[docs]class Model(AbstractModel):
"""An abstract class upon which all models are based.
Attributes:
config (ModelConfig): The configuration for the model
"""
# model scoring type
LIKELIHOOD_SCORING = "log_loss"
ALLOWED_CLASSIFIER_TYPES: List[str] = NotImplemented
def __init__(self, config):
super().__init__(config)
self._label_encoder = get_label_encoder(self.config)
self._current_params = None
self._clf = None
self.cv_loss_ = None
def _fit(self, examples, labels, params=None):
raise NotImplementedError
def _get_model_constructor(self):
raise NotImplementedError
def _fit_cv(self, examples, labels, groups=None, selection_settings=None, fixed_params=None):
"""Called by the fit method when cross validation parameters are passed in. Runs cross
validation and returns the best estimator and parameters.
Args:
examples (list): A list of examples. Should be in the format expected by the \
underlying estimator.
labels (list): The target output values.
groups (None, optional): Same length as examples and labels. Used to group \
examples when splitting the dataset into train/test
selection_settings (dict, optional): A dictionary containing the cross \
validation selection settings.
"""
selection_settings = selection_settings or self.config.param_selection
cv_iterator = self._get_cv_iterator(selection_settings)
if selection_settings is None:
return self._fit(examples, labels, self.config.params), self.config.params
cv_type = selection_settings["type"]
num_splits = cv_iterator.get_n_splits(examples, labels, groups)
logger.info(
"Selecting hyperparameters using %s cross-validation with %s split%s",
cv_type,
num_splits,
"" if num_splits == 1 else "s",
)
scoring = self._get_cv_scorer(selection_settings)
n_jobs = selection_settings.get("n_jobs", -1)
param_grid = self._convert_params(selection_settings["grid"], labels)
model_class = self._get_model_constructor()
if fixed_params:
for key, val in fixed_params.items():
if key not in param_grid:
param_grid[key] = [val]
else:
logger.info(
"Found parameter %s both in params and param_selection. Proceeding with param_selection.. \
(If you did not set this, it could be a Mindmeld default.)",
key
)
estimator, param_grid = self._get_cv_estimator_and_params(
model_class, param_grid
)
# set GridSearchCV's return_train_score attribute to False improves cross-validation
# runtime perf as it doesn't have to compute training scores and which we don't consume
grid_cv = GridSearchCV(
estimator=estimator,
scoring=scoring,
param_grid=param_grid,
cv=cv_iterator,
n_jobs=n_jobs,
return_train_score=False,
)
model = grid_cv.fit(examples, y=labels, groups=groups)
for idx, params in enumerate(model.cv_results_["params"]):
logger.debug("Candidate parameters: %s", params)
std_err = (
2.0
* model.cv_results_["std_test_score"][idx]
/ math.sqrt(model.n_splits_)
)
if scoring == Model.LIKELIHOOD_SCORING:
msg = "Candidate average log likelihood: {:.4} ± {:.4}"
else:
msg = "Candidate average accuracy: {:.2%} ± {:.2%}"
# pylint: disable=logging-format-interpolation
logger.debug(msg.format(model.cv_results_["mean_test_score"][idx], std_err))
if scoring == Model.LIKELIHOOD_SCORING:
msg = "Best log likelihood: {:.4}, params: {}"
self.cv_loss_ = -model.best_score_
else:
msg = "Best accuracy: {:.2%}, params: {}"
self.cv_loss_ = 1 - model.best_score_
best_params = self._process_cv_best_params(model.best_params_)
# pylint: disable=logging-format-interpolation
logger.info(msg.format(model.best_score_, best_params))
return model.best_estimator_, model.best_params_
def _get_cv_scorer(self, selection_settings):
"""
Returns the scorer to use based on the selection settings and classifier type.
"""
raise NotImplementedError
@staticmethod
def _clean_params(model_class, params):
"""
Make sure that the params to be passed into model construction meet the model's expected
params
"""
expected_params = signature(model_class).parameters.keys()
if len(expected_params) == 1 and "parameters" in expected_params:
# if there is only one key, this is our custom TaggerModel which is supposed to be a
# pass-through and pass everything to the actual sklearn model
return params
result = copy.deepcopy(params)
for param in params:
if param not in expected_params:
msg = (
"Unexpected param `{param}`, dropping it from model config.".format(
param=param
)
)
logger.warning(msg)
result.pop(param)
return result
def _get_cv_estimator_and_params(self, model_class, param_grid):
param_grid = self._clean_params(model_class, param_grid)
if "warm_start" in signature(model_class).parameters.keys():
# Warm start helps speed up cross-validation for some models such as random forest
return model_class(warm_start=True), param_grid
else:
return model_class(), param_grid
@staticmethod
def _process_cv_best_params(best_params):
return best_params
[docs] def select_params(self, examples, labels, selection_settings=None):
"""Selects the best set of hyper-parameters for a given set of examples and true labels
through cross-validation
Args:
examples: A list of example queries
labels: A list of labels associated with the queries
selection_settings: A dictionary of parameter lists to select from
Returns:
dict: A dictionary of optimized parameters to use
"""
raise NotImplementedError
def _convert_params(self, param_grid, y, is_grid=True):
"""Convert the params from the style given by the config to the style
passed in to the actual classifier.
Args:
param_grid (dict): lists of classifier parameter values, keyed by \
parameter name
y (list): A list of labels
is_grid (bool, optional): Indicates whether param_grid is actually a grid \
or a params dict.
"""
raise NotImplementedError
def _get_effective_config(self):
"""Create a model config object for the current effective config (after \
param selection)
Returns:
ModelConfig
"""
config_dict = self.config.to_dict()
config_dict.pop("param_selection")
config_dict["params"] = self._current_params
return ModelConfig(**config_dict)
[docs] def register_resources(self, **kwargs):
"""Registers resources which are accessible to feature extractors
Args:
**kwargs: dictionary of resources to register
"""
self._resources.update(kwargs)
[docs] def get_feature_matrix(self, examples, y=None, fit=False):
raise NotImplementedError
def _extract_features(self, example, dynamic_resource=None, text_preparation_pipeline=None):
"""Gets all features from an example.
Args:
example: An example object.
dynamic_resource (dict, optional): A dynamic resource to aid NLP inference
text_preparation_pipeline (TextPreparationPipeline): MindMeld text processing object
Returns:
(dict of str: number): A dict of feature names to their values.
"""
example_type = self.config.example_type
feat_set = {}
workspace_resource = ingest_dynamic_gazetteer(
self._resources, dynamic_resource, text_preparation_pipeline
)
workspace_features = copy.deepcopy(self.config.features)
enable_stemming = workspace_features.pop(ENABLE_STEMMING, False)
for name, kwargs in workspace_features.items():
if callable(kwargs):
# a feature extractor function was passed in directly
feat_extractor = kwargs
else:
kwargs[ENABLE_STEMMING] = enable_stemming
feat_extractor = get_feature_extractor(example_type, name)(**kwargs)
feat_set.update(feat_extractor(example, workspace_resource))
return feat_set
def _get_cv_iterator(self, settings):
if not settings:
return None
cv_type = settings["type"]
try:
cv_iterator = {
"k-fold": self._k_fold_iterator,
"shuffle": self._shuffle_iterator,
"group-k-fold": self._groups_k_fold_iterator,
"group-shuffle": self._groups_shuffle_iterator,
"stratified-k-fold": self._stratified_k_fold_iterator,
"stratified-shuffle": self._stratified_shuffle_iterator,
}.get(cv_type)(settings)
except KeyError as e:
raise ValueError("Unknown param selection type: {!r}".format(cv_type)) from e
return cv_iterator
@staticmethod
def _k_fold_iterator(settings):
k = settings["k"]
return KFold(n_splits=k)
@staticmethod
def _shuffle_iterator(settings):
k = settings["k"]
n = settings.get("n", k)
test_size = 1.0 / k
return ShuffleSplit(n_splits=n, test_size=test_size)
@staticmethod
def _groups_k_fold_iterator(settings):
k = settings["k"]
return GroupKFold(n_splits=k)
@staticmethod
def _groups_shuffle_iterator(settings):
k = settings["k"]
n = settings.get("n", k)
test_size = 1.0 / k
return GroupShuffleSplit(n_splits=n, test_size=test_size)
@staticmethod
def _stratified_k_fold_iterator(settings):
k = settings["k"]
return StratifiedKFold(n_splits=k)
@staticmethod
def _stratified_shuffle_iterator(settings):
k = settings["k"]
n = settings.get("n", k)
test_size = 1.0 / k
return StratifiedShuffleSplit(n_splits=n, test_size=test_size)
[docs] def requires_resource(self, resource):
example_type = self.config.example_type
for name, kwargs in self.config.features.items():
if callable(kwargs):
# a feature extractor function was passed in directly
feature_extractor = kwargs
else:
feature_extractor = get_feature_extractor(example_type, name)
if (
"requirements" in feature_extractor.__dict__
and resource in feature_extractor.requirements
):
return True
return False
[docs] def initialize_resources(self, resource_loader, examples=None, labels=None):
"""Load the required resources for feature extractors. Each feature extractor uses \
@requires decorator to declare required resources. Based on feature list in model config \
a list of required resources are compiled, and the passed in resource loader is then used \
to load the resources accordingly.
Args:
resource_loader (ResourceLoader): application resource loader object
examples (list): Optional. A list of examples.
labels (list): Optional. A parallel list to examples. The gold labels \
for each example.
"""
# get list of resources required by feature extractors
required_resources = self.config.required_resources()
enable_stemming = ENABLE_STEMMING in required_resources
resource_builders = {}
for rname in required_resources:
if rname in self._resources:
continue
if rname == GAZETTEER_RSC:
self._resources[rname] = resource_loader.get_gazetteers()
elif rname == SENTIMENT_ANALYZER:
self._resources[rname] = resource_loader.get_sentiment_analyzer()
elif rname == SYS_TYPES_RSC:
self._resources[rname] = resource_loader.get_sys_entity_types(labels)
elif rname == WORD_FREQ_RSC:
resource_builders[rname] = resource_loader.WordFreqBuilder()
elif rname == CHAR_NGRAM_FREQ_RSC:
l, t = self.config.get_ngram_lengths_and_thresholds(rname)
resource_builders[rname] = resource_loader.CharNgramFreqBuilder(l, t)
elif rname == WORD_NGRAM_FREQ_RSC:
l, t = self.config.get_ngram_lengths_and_thresholds(rname)
resource_builders[rname] = \
resource_loader.WordNgramFreqBuilder(l, t, enable_stemming)
elif rname == QUERY_FREQ_RSC:
resource_builders[rname] = resource_loader.QueryFreqBuilder(enable_stemming)
if resource_builders:
for query in examples:
for rname, builder in resource_builders.items():
builder.add(query)
for rname, builder in resource_builders.items():
self._resources[rname] = builder.get_resource()
# Always initialize the global resource for tokenization, which is not a
# feature-specific resource
self._resources[
"text_preparation_pipeline"
] = resource_loader.get_text_preparation_pipeline()
def _validate_model_configs(self) -> Union[TypeError, ValueError]:
for arg, val in {
"features": self.config.features,
"example_type": self.config.example_type
}.items():
if val is None:
raise TypeError("__init__() missing required argument {!r}".format(arg))
if self.config.params is None and (
self.config.param_selection is None or self.config.param_selection.get("grid") is None
):
raise ValueError(
"__init__() One of 'params' and 'param_selection' is required"
)
[docs]class PytorchModel(AbstractModel):
ALLOWED_CLASSIFIER_TYPES: List[str] = NotImplemented # to be implemented in child classes
def __init__(self, config):
if not _is_module_available('torch'):
raise ImportError("Install the extra 'torch' library by runnning "
"'pip install mindmeld[torch]' to use pytorch based neural models")
super().__init__(config)
self._label_encoder = get_label_encoder(self.config)
self._class_encoder = SKLabelEncoder()
self._query_text_type = None
self._clf = None
[docs] def initialize_resources(self, resource_loader, examples=None, labels=None):
del resource_loader, examples, labels
@staticmethod
def _validate_training_data(examples: List[Any], labels: Union[List[int], List[List[int]]]):
if len(examples) != len(labels):
msg = f"Number of 'labels' ({len(labels)}) must be same as number of 'examples' " \
f"({len(examples)})"
raise AssertionError(msg)
def _set_query_text_type(self, params: Dict = None, default: str = None):
"""
Returns the query text type to use for obtaining training examples from Query objects.
Args:
params (Dict, optional): The config params passed in to train the model
default (str, optional): The default text type to use in case no related configs found
"""
if params is None and self._query_text_type:
# this condition is satisfied during loading of models
return
# While the key 'query_text_type' in config params allows for end-users to configure the
# choice of text_type to be used, it also needs the user to have knowledge about different
# text types in a Query object. Three values are available as of now-
# ["text", "processed_text", "normalized_text"].
query_text_type = params.get("query_text_type", default) if params else default
# consider raw text for pretrained transformer models
if not query_text_type:
if params and EmbedderType(params.get("embedder_type")) == EmbedderType.BERT:
query_text_type = "text"
# if a NoneType value is found, use the processed_text by default
query_text_type = query_text_type or "processed_text"
# validation
allowed_text_types = ["text", "processed_text", "normalized_text"]
if query_text_type not in allowed_text_types:
msg = f"The params 'query_text_type' can only be among " \
f"{allowed_text_types} but found value {query_text_type}."
logger.error(msg)
raise ValueError(msg)
self._query_text_type = query_text_type # this var is dumped and loaded when loading models
def _get_texts_from_examples(self, examples: PQL.QueryIterator) -> List[str]:
"""
Method that decides which text type- processed_text or raw_text -that needs to be used for
neural model training/inference.
Args:
examples (QueryIterator): A list of ProcessedQuery objects.
Returns:
texts (List[str]): A list of strings obtained from the query objects based on the
provided input configs
"""
if not self._query_text_type:
msg = "The instance attribute '_query_text_type' must be set by calling " \
"_set_query_text_type() method before calling the " \
"_get_texts_from_examples() method."
logger.debug(msg)
raise ValueError(msg)
return [getattr(example, self._query_text_type) for example in examples]
[docs]class AbstractModelFactory(ABC):
"""
Abstract class for individual model factories like TextModelFactory and TaggerModelFactory
"""
[docs] @abstractmethod
def get_model_cls(self, config: ModelConfig) -> Type[AbstractModel]:
raise NotImplementedError