Source code for mindmeld.models.labels

# -*- coding: utf-8 -*-
#
# Copyright (c) 2015 Cisco Systems, Inc. and others.  All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module contains classes related to encoding labels for models defined in the models
subpackage."""

import logging

from .helpers import register_label
from .taggers.taggers import (
    get_entities_from_tags,
    get_tags_from_entities,
)
from ..system_entity_recognizer import SystemEntityRecognizer

logger = logging.getLogger(__name__)


class LabelEncoder:
    """Converts between rich label objects and the simple formats a model consumes.

    Rich label objects (for example a ProcessedQuery) are not something a model
    can work with directly. A MindMeld model uses its label encoder at fit time
    to encode labels into a form it can train on, and at predict time to decode
    its predictions back into label objects.

    This base implementation is the identity transform in both directions.
    """

    def __init__(self, config):
        """Creates the encoder.

        Args:
            config (ModelConfig): The model configuration.
        """
        self.config = config

    @staticmethod
    def encode(labels, **kwargs):
        """Encodes a list of label objects into a vector of classes.

        The base encoder performs no transformation.

        Args:
            labels (list): The labels to encode.
            **kwargs: Ignored here; accepted so that subclasses can take extra
                context through the same call signature.

        Returns:
            list: The very same ``labels`` object, unmodified.
        """
        del kwargs  # unused by the identity transform
        return labels

    @staticmethod
    def decode(classes, **kwargs):
        """Decodes a vector of classes back into a list of labels.

        The base encoder performs no transformation.

        Args:
            classes (list): The classes to decode.
            **kwargs: Ignored here; accepted so that subclasses can take extra
                context through the same call signature.

        Returns:
            list: The very same ``classes`` object, unmodified.
        """
        del kwargs  # unused by the identity transform
        return classes
class EntityLabelEncoder(LabelEncoder):
    """Label encoder that converts between per-query entity lists and tag sequences.

    ``encode`` turns each query's entities into a sequence of joint app and
    system entity tags. ``decode`` turns tag sequences back into entities.
    """

    def __init__(self, config):
        """Initializes an encoder.

        Args:
            config (ModelConfig): The model configuration.
        """
        super().__init__(config)

    def _get_tag_scheme(self):
        """Returns the configured tag scheme, upper-cased (``"IOB"`` if not set)."""
        # NOTE(review): model_settings is presumably None when no settings were
        # configured; treat that the same as an empty settings mapping instead
        # of raising AttributeError.
        model_settings = self.config.model_settings or {}
        return model_settings.get("tag_scheme", "IOB").upper()

    def encode(self, labels, **kwargs):
        """Gets a list of joint app and system tags from each query's entities.

        Tags use the configured tag scheme (``"IOB"`` by default).

        Args:
            labels (list): A list of labels associated with each query. Each
                label is the list of entities for the corresponding example.
            kwargs (dict): A dict containing at least the "examples" key, which is
                a list of queries to process, aligned index-for-index with
                ``labels``.

        Returns:
            list: A list of lists of joint app and system tags, one list per
                label.

        Raises:
            KeyError: If "examples" is not provided.
        """
        examples = kwargs["examples"]
        scheme = self._get_tag_scheme()
        # Index into examples (rather than zip) so that a missing example raises
        # IndexError instead of the label being silently dropped.
        return [
            get_tags_from_entities(examples[idx], entities, scheme)
            for idx, entities in enumerate(labels)
        ]

    def decode(self, tags_by_example, **kwargs):
        """Decodes the labels from the tags passed in for each query.

        Args:
            tags_by_example (list): A list of tags per query.
            kwargs (dict): A dict containing at least the "examples" key, which is
                a list of queries to process, aligned index-for-index with
                ``tags_by_example``.

        Returns:
            list: A list of decoded labels (entities) per query.

        Raises:
            KeyError: If "examples" is not provided.
        """
        examples = kwargs["examples"]
        # get_instance() takes no per-example input (presumably a singleton
        # accessor), so look it up once instead of once per example.
        system_entity_recognizer = SystemEntityRecognizer.get_instance()
        return [
            get_entities_from_tags(examples[idx], tags, system_entity_recognizer)
            for idx, tags in enumerate(tags_by_example)
        ]
# Register each encoder under its label-type key via register_label:
# "class" uses the identity LabelEncoder and "entities" uses the tag-based
# EntityLabelEncoder.
register_label("class", LabelEncoder)
register_label("entities", EntityLabelEncoder)