Source code for mindmeld.core

# -*- coding: utf-8 -*-
#
# Copyright (c) 2015 Cisco Systems, Inc. and others.  All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""This module contains a collection of the core data structures used in MindMeld."""
from copy import deepcopy
import logging
from typing import Optional, List, Dict
import immutables

from .constants import SYSTEM_ENTITY_PREFIX


TEXT_FORM_RAW = 0
TEXT_FORM_PROCESSED = 1
TEXT_FORM_NORMALIZED = 2
TEXT_FORMS = [TEXT_FORM_RAW, TEXT_FORM_PROCESSED, TEXT_FORM_NORMALIZED]

logger = logging.getLogger(__name__)

# The date keys are extracted from here
# https://github.com/wit-ai/duckling_old/blob/a4bc34e3e945d403a9417df50c1fb2172d56de3e/src/duckling/time/obj.clj#L21 # noqa E722
TIME_GRAIN_TO_ORDER = {
    "year": 8,
    "quarter": 7,
    "month": 6,
    "week": 5,
    "day": 4,
    "hour": 3,
    "minute": 2,
    "second": 1,
    "milliseconds": 0,
}


def _sort_by_lowest_time_grain(system_entities):
    return sorted(
        system_entities,
        key=lambda query_entity: TIME_GRAIN_TO_ORDER[
            query_entity.entity.value["grain"]
        ],
    )
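

# A minimal illustrative sketch of the sorter above: entities resolved with a finer
# time grain sort first. The SimpleNamespace objects below are hypothetical stand-ins
# for QueryEntity objects, which expose the same attribute path.
#
#   >>> from types import SimpleNamespace as NS
#   >>> day = NS(entity=NS(value={"grain": "day"}))
#   >>> minute = NS(entity=NS(value={"grain": "minute"}))
#   >>> [e.entity.value["grain"] for e in _sort_by_lowest_time_grain([day, minute])]
#   ['minute', 'day']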


class Bunch(dict):
    """Dictionary-like object that exposes its keys as attributes.

    Inspired by scikit-learn's Bunches

    >>> b = Bunch(a=1, b=2)
    >>> b['b']
    2
    >>> b.b
    2
    >>> b.a = 3
    >>> b['a']
    3
    >>> b.c = 6
    >>> b['c']
    6
    """

    def __init__(self, **kwargs):
        super().__init__(kwargs)

    def __setattr__(self, key, value):
        self[key] = value

    def __dir__(self):
        return self.keys()

    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError as e:
            raise AttributeError(key) from e

    def __setstate__(self, state):
        # A Bunch pickles as a plain dict; there is no extra state to restore.
        pass


class Span:
    """Object representing a text span with start and end indices

    Attributes:
        start (int): The index from the original text that represents the start of the span
        end (int): The index from the original text that represents the end of the span
    """

    __slots__ = ["start", "end"]

    def __init__(self, start, end):
        assert start <= end, "Span 'start' must be less than or equal to 'end'"
        self.start = start
        self.end = end

    def to_cache(self):
        return {
            "start": self.start,
            "end": self.end
        }

    @staticmethod
    def from_cache(obj):
        return Span(**obj)

    def to_dict(self):
        """Converts the span into a dictionary"""
        return {"start": self.start, "end": self.end}

    def slice(self, obj):
        """Returns the slice of the object for this span

        Args:
            obj: The object to slice

        Returns:
            The slice of the passed in object for this span
        """
        return obj[self.start : self.end + 1]

    def shift(self, offset):
        """Returns a new span with start and end shifted by offset

        Args:
            offset (int): The number to change start and end by

        Returns:
            Span: The shifted span
        """
        return Span(self.start + offset, self.end + offset)

    def has_overlap(self, other):
        """Determines whether two spans overlap."""
        return self.end >= other.start and other.end >= self.start

    @staticmethod
    def get_largest_non_overlapping_candidates(spans):
        """Finds the set of the largest non-overlapping candidates.

        Args:
            spans (list of Span): Candidate spans. Note that the list is sorted in place,
                largest span first.

        Returns:
            selected_spans (list): List of the largest non-overlapping spans.
        """
        # Span comparisons are by length, so a reverse sort puts the largest candidates
        # first and the greedy selection below keeps the biggest survivors.
        spans.sort(reverse=True)
        selected_spans = []
        for span in spans:
            has_overlaps = [
                span.has_overlap(selected_span) for selected_span in selected_spans
            ]
            if not any(has_overlaps):
                selected_spans.append(span)
        return selected_spans

    def __iter__(self):
        for index in range(self.start, self.end + 1):
            yield index

    def __len__(self):
        return self.end - self.start + 1

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.start == other.start and self.end == other.end
        return NotImplemented

    def __gt__(self, other):
        if isinstance(other, self.__class__):
            return len(self) > len(other)
        return NotImplemented

    def __ge__(self, other):
        if isinstance(other, self.__class__):
            return len(self) >= len(other)
        return NotImplemented

    def __lt__(self, other):
        if isinstance(other, self.__class__):
            return len(self) < len(other)
        return NotImplemented

    def __le__(self, other):
        if isinstance(other, self.__class__):
            return len(self) <= len(other)
        return NotImplemented

    def __ne__(self, other):
        if isinstance(other, self.__class__):
            return not self.__eq__(other)
        return NotImplemented

    def __repr__(self):
        return "{}(start={}, end={})".format(
            self.__class__.__name__, self.start, self.end
        )
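

# An illustrative sketch of Span behavior; note that both the start and end indices
# are inclusive, and that candidate selection is greedy by span length.
#
#   >>> span = Span(0, 4)
#   >>> span.slice("hello world")
#   'hello'
#   >>> span.shift(6).slice("hello world")
#   'world'
#   >>> Span(0, 4).has_overlap(Span(4, 8))
#   True
#   >>> Span.get_largest_non_overlapping_candidates([Span(0, 4), Span(0, 8), Span(6, 8)])
#   [Span(start=0, end=8)]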


class Query:
    """The query object is responsible for processing and normalizing raw user text input so that
    classifiers can deal with it. A query stores three forms of text: raw text, processed text,
    and normalized text. The query object is also responsible for translating text ranges across
    these forms.

    Attributes:
        raw_text (str): the original input text
        processed_text (str): the text after it has been preprocessed. The pre-processing happens
            at the application level and is generally used for special characters
        normalized_tokens (tuple of str): a list of normalized tokens
        system_entity_candidates (tuple): A list of system entities extracted from the text
        locale (str, optional): The locale representing the ISO 639-1 language code and \
            ISO3166 alpha 2 country code separated by an underscore character.
        language (str, optional): The language code representing ISO 639-1 language codes.
        time_zone (str): The IANA id for the time zone in which the query originated
            such as 'America/Los_Angeles'
        timestamp (int, optional): A unix timestamp used as the reference time.
            If not specified, the current system time is used. If `time_zone`
            is not also specified, this parameter is ignored
        stemmed_tokens (list): A sequence of stemmed tokens for the query text
    """

    # TODO: look into using __slots__

    def __init__(
        self,
        raw_text,
        processed_text,
        normalized_tokens,
        char_maps,
        locale=None,
        language=None,
        time_zone=None,
        timestamp=None,
        stemmed_tokens=None,
    ):
        """Creates a query object

        Args:
            raw_text (str): the original input text
            processed_text (str): the input text after it has been preprocessed
            normalized_tokens (list of dict): List of tokens output by a tokenizer
            char_maps (dict): Mappings between character indices in raw, processed
                and normalized text
        """
        self._normalized_tokens = normalized_tokens
        norm_text = " ".join([t["entity"] for t in self._normalized_tokens])
        self._texts = (raw_text, processed_text, norm_text)
        self._char_maps = char_maps
        self.system_entity_candidates = ()
        self._locale = locale
        self._language = language
        self._time_zone = time_zone
        self._timestamp = timestamp
        self.stemmed_tokens = stemmed_tokens or tuple()

    def char_maps_to_cache(self):
        # dump the tuple keys to a string and trim the ( ) symbols from the ends
        return {str(k)[1:-1]: v for k, v in self._char_maps.items()}

    def to_cache(self):
        return {
            "raw_text": self.text,
            "processed_text": self.processed_text,
            "normalized_tokens": self._normalized_tokens,
            "char_maps": self.char_maps_to_cache(),
            "locale": self._locale,
            "language": self._language,
            "time_zone": self._time_zone,
            "timestamp": self._timestamp,
            "stemmed_tokens": self.stemmed_tokens,
            "system_entity_candidates": [e.to_cache() for e in self.system_entity_candidates]
        }

    @staticmethod
    def char_maps_from_cache(obj):
        result = {}
        for k, v in obj.items():
            # Convert the string key back into a tuple
            tuple_key = tuple(map(int, k.split(",")))
            # Convert the inner dicts' string keys back into integers
            result[tuple_key] = {int(k2): v2 for k2, v2 in v.items()}
        return result

    @staticmethod
    def from_cache(obj):
        system_entity_candidates = obj.pop("system_entity_candidates")
        obj["char_maps"] = Query.char_maps_from_cache(obj["char_maps"])
        result = Query(**obj)
        result.system_entity_candidates = [
            Entity.from_cache_typed(e) for e in system_entity_candidates
        ]
        return result

    @property
    def text(self):
        """The original input text"""
        return self._texts[TEXT_FORM_RAW]

    @property
    def processed_text(self):
        """The input text after it has been preprocessed"""
        return self._texts[TEXT_FORM_PROCESSED]

    @property
    def normalized_text(self):
        """The normalized input text"""
        return self._texts[TEXT_FORM_NORMALIZED]

    @property
    def stemmed_text(self):
        """The stemmed input text"""
        return " ".join(self.stemmed_tokens)

    @property
    def normalized_tokens(self):
        """The tokens of the normalized input text"""
        return tuple((token["entity"] for token in self._normalized_tokens))

    @property
    def language(self):
        """Language of the query specified using an ISO 639-2 code."""
        return self._language

    @property
    def locale(self):
        """The locale representing the ISO 639-1/2 language code and ISO3166 alpha 2
        country code separated by an underscore character."""
        return self._locale

    @property
    def time_zone(self):
        """The IANA id for the time zone in which the query originated
        such as 'America/Los_Angeles'.
        """
        return self._time_zone

    @property
    def timestamp(self):
        """A unix timestamp for when the query was created. If `time_zone` is None,
        this parameter is ignored.
        """
        return self._timestamp

    def get_verbose_normalized_tokens(self):
        """Returns a list of dictionaries containing details of each normalized token."""
        return self._normalized_tokens

    def get_text_form(self, form):
        """Programmatically retrieves text by form

        Args:
            form (int): A valid text form (TEXT_FORM_RAW, TEXT_FORM_PROCESSED, or
                TEXT_FORM_NORMALIZED)

        Returns:
            str: The requested text
        """
        return self._texts[form]

    def get_system_entity_candidates(self, sys_types):
        """Returns the system entity candidates of the specified types.

        Args:
            sys_types (set of str): A set of entity types to select

        Returns:
            list: Candidate system entities of the specified types
        """
        return [e for e in self.system_entity_candidates if e.entity.type in sys_types]

    def transform_span(self, text_span, form_in, form_out):
        """Transforms a text range from one form to another.

        Args:
            text_span (Span): the text span being transformed
            form_in (int): the input text form. Should be one of TEXT_FORM_RAW,
                TEXT_FORM_PROCESSED or TEXT_FORM_NORMALIZED
            form_out (int): the output text form. Should be one of TEXT_FORM_RAW,
                TEXT_FORM_PROCESSED or TEXT_FORM_NORMALIZED

        Returns:
            Span: the equivalent span of text in the output form
        """
        return Span(
            self.transform_index(text_span.start, form_in, form_out),
            self.transform_index(text_span.end, form_in, form_out),
        )

    def transform_index(self, index, form_in, form_out):
        """Transforms a text index from one form to another.

        Args:
            index (int): the index being transformed
            form_in (int): the input form. Should be one of TEXT_FORM_RAW,
                TEXT_FORM_PROCESSED or TEXT_FORM_NORMALIZED
            form_out (int): the output form. Should be one of TEXT_FORM_RAW,
                TEXT_FORM_PROCESSED or TEXT_FORM_NORMALIZED

        Returns:
            int: the equivalent index of text in the output form
        """
        if form_in not in TEXT_FORMS or form_out not in TEXT_FORMS:
            raise ValueError("Invalid text form")

        if form_in > form_out:
            while form_in > form_out:
                index = self._unprocess_index(index, form_in)
                form_in -= 1
        else:
            while form_in < form_out:
                index = self._process_index(index, form_in)
                form_in += 1
        return index

    def _process_index(self, index, form_in):
        if form_in == TEXT_FORM_NORMALIZED:
            raise ValueError(
                "'{}' form cannot be processed".format(TEXT_FORM_NORMALIZED)
            )
        mapping_key = (form_in, (form_in + 1))
        try:
            mapping = self._char_maps[mapping_key]
        except KeyError:
            # mapping doesn't exist -> use identity
            return index
        # None for mapping means 1-1 mapping
        try:
            return mapping[index] if mapping else index
        except KeyError as e:
            raise ValueError("Invalid index {}".format(index)) from e

    def _unprocess_index(self, index, form_in):
        if form_in == TEXT_FORM_RAW:
            raise ValueError("'{}' form cannot be unprocessed".format(TEXT_FORM_RAW))
        mapping_key = (form_in, (form_in - 1))
        try:
            mapping = self._char_maps[mapping_key]
        except KeyError:
            # mapping doesn't exist -> use identity
            return index
        # None for mapping means 1-1 mapping
        try:
            return mapping[index] if mapping else index
        except KeyError as e:
            raise ValueError("Invalid index {}".format(index)) from e

    def get_token_ngram_raw_ngram_span(self, tokens, start_token_index, end_token_index):
        """Returns the n-gram of normalized tokens in the inclusive index range
        [start_token_index, end_token_index], the corresponding raw text n-gram,
        and its span in the raw text.
        """
        token_ngram = tuple(
            [token['entity'] for token in tokens[start_token_index:end_token_index + 1]]
        )
        last_raw_start = tokens[end_token_index]['raw_start']
        last_raw_entity = tokens[end_token_index]['entity']
        first_raw_start = tokens[start_token_index]['raw_start']
        result_span = Span(first_raw_start, last_raw_start + len(last_raw_entity) - 1)
        raw_ngram = self.text[result_span.start: result_span.end + 1]
        return token_ngram, raw_ngram, result_span

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.__dict__ == other.__dict__
        return NotImplemented

    def __ne__(self, other):
        if isinstance(other, self.__class__):
            return not self.__eq__(other)
        return NotImplemented

    def __repr__(self):
        return "<{} {!r}>".format(self.__class__.__name__, self.text)
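

# An illustrative sketch of index transformation, assuming whitespace tokenization
# and no preprocessing. An empty char_maps dict means every index maps through
# unchanged, so all three text forms line up.
#
#   >>> tokens = [{"entity": "order", "raw_start": 0}, {"entity": "pizza", "raw_start": 6}]
#   >>> q = Query("order pizza", "order pizza", tokens, {})
#   >>> q.normalized_text
#   'order pizza'
#   >>> q.transform_span(Span(6, 10), TEXT_FORM_RAW, TEXT_FORM_NORMALIZED).slice(q.normalized_text)
#   'pizza'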


class ProcessedQuery:
    """A processed query contains a query and the additional metadata that has been labeled or
    predicted.

    Attributes:
        query (Query): The underlying query object.
        domain (str): The domain of the query
        entities (list): A list of entities present in this query
        intent (str): The intent of the query
        is_gold (bool): Indicates whether the details in this query were predicted or human
            labeled
        nbest_transcripts_queries (list): A list of n-best transcript queries
        nbest_transcripts_entities (list): A list of lists of entities for each query
        nbest_aligned_entities (list): A list of lists of aligned entities
        confidence (dict): A dictionary of the class probabilities for the domain and intent
            classifiers
    """

    def __init__(
        self,
        query,
        domain=None,
        intent=None,
        entities=None,
        is_gold=False,
        nbest_transcripts_queries=None,
        nbest_transcripts_entities=None,
        nbest_aligned_entities=None,
        confidence=None,
    ):
        self.query = query
        self.domain = domain
        self.intent = intent
        self.entities = None if entities is None else tuple(entities)
        self.is_gold = is_gold
        self.nbest_transcripts_queries = nbest_transcripts_queries
        self.nbest_transcripts_entities = nbest_transcripts_entities
        self.nbest_aligned_entities = nbest_aligned_entities
        self.confidence = confidence

    def to_dict(self):
        """Converts the processed query into a dictionary"""
        base = {
            "text": self.query.text,
            "domain": self.domain,
            "intent": self.intent,
            "entities": None
            if self.entities is None
            else [e.to_dict() for e in self.entities],
        }
        if self.nbest_transcripts_queries:
            base["nbest_transcripts_text"] = [
                q.text for q in self.nbest_transcripts_queries
            ]
        if self.nbest_transcripts_entities:
            base["nbest_transcripts_entities"] = [
                [e.to_dict() for e in n_entities]
                for n_entities in self.nbest_transcripts_entities
            ]
        if self.nbest_aligned_entities:
            base["nbest_aligned_entities"] = [
                [{"text": e.entity.text, "type": e.entity.type} for e in n_entities]
                for n_entities in self.nbest_aligned_entities
            ]
        if self.confidence:
            base["confidences"] = self.confidence
        return base

    def to_cache(self):
        obj = {
            "query": self.query.to_cache(),
            "domain": self.domain,
            "intent": self.intent,
            "entities": [e.to_cache() for e in self.entities],
            "is_gold": self.is_gold,
            "confidence": self.confidence
        }
        if self.nbest_transcripts_queries:
            obj["nbest_transcripts_queries"] = [
                q.to_cache() for q in self.nbest_transcripts_queries
            ]
        if self.nbest_transcripts_entities:
            obj["nbest_transcripts_entities"] = [
                e.to_cache() for e in self.nbest_transcripts_entities
            ]
        if self.nbest_aligned_entities:
            obj["nbest_aligned_entities"] = [
                e.to_cache() for e in self.nbest_aligned_entities
            ]
        return obj

    @staticmethod
    def from_cache(obj):
        obj["query"] = Query.from_cache(obj["query"])
        obj["entities"] = [Entity.from_cache_typed(e) for e in obj["entities"]]
        if "nbest_transcripts_queries" in obj:
            obj["nbest_transcripts_queries"] = [
                Query.from_cache(q) for q in obj["nbest_transcripts_queries"]
            ]
        if "nbest_transcripts_entities" in obj:
            obj["nbest_transcripts_entities"] = [
                Entity.from_cache_typed(e) for e in obj["nbest_transcripts_entities"]
            ]
        if "nbest_aligned_entities" in obj:
            obj["nbest_aligned_entities"] = [
                Entity.from_cache_typed(e) for e in obj["nbest_aligned_entities"]
            ]
        return ProcessedQuery(**obj)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.__dict__ == other.__dict__
        return NotImplemented

    def __ne__(self, other):
        if isinstance(other, self.__class__):
            return not self.__eq__(other)
        return NotImplemented

    def __repr__(self):
        msg = "<{} {!r}, domain: {!r}, intent: {!r}, {!r} entities{}>"
        return msg.format(
            self.__class__.__name__,
            self.query.text,
            self.domain,
            self.intent,
            # entities defaults to None, so guard against len(None)
            len(self.entities or ()),
            ", gold" if self.is_gold else "",
        )
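

# Continuing the Query sketch above, a hand-labeled ProcessedQuery (the domain and
# intent names here are hypothetical) serializes like so:
#
#   >>> pq = ProcessedQuery(q, domain="ordering", intent="build_order", entities=[], is_gold=True)
#   >>> pq.to_dict()
#   {'text': 'order pizza', 'domain': 'ordering', 'intent': 'build_order', 'entities': []}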


class NestedEntity:
    """An entity with the context of the query it came from, along with \
        information like the entity's parent and children.

    Attributes:
        texts (tuple): Tuple containing the three forms of text: raw text, \
            processed text, and normalized text
        spans (tuple): Tuple containing the character index spans of the \
            text for this entity for each text form
        token_spans (tuple): Tuple containing the token index spans of the \
            text for this entity for each text form
        entity (Entity): The entity object
        parent (NestedEntity): The parent of the nested entity
        children (tuple of NestedEntity): A tuple of children nested entities
    """

    def __init__(self, texts, spans, token_spans, entity, children=None):
        self._texts = texts
        self._spans = spans
        self._token_spans = token_spans
        self.entity = entity
        self.parent = None
        for child in children or ():
            child.parent = self
        if children:
            self.children = tuple(sorted(children, key=lambda c: c.span.start))
        else:
            self.children = None

    def to_cache(self):
        obj = {
            "class": self.__class__.__name__,
            "texts": self._texts,
            "spans": [s.to_cache() for s in self._spans],
            "token_spans": [s.to_cache() for s in self._token_spans],
            "entity": self.entity.to_cache(),
        }
        if self.children:
            obj["children"] = [e.to_cache() for e in self.children]
        return obj

    @staticmethod
    def from_cache(obj):
        obj["spans"] = [Span.from_cache(s) for s in obj["spans"]]
        obj["token_spans"] = [Span.from_cache(s) for s in obj["token_spans"]]
        obj["entity"] = Entity.from_cache_typed(obj["entity"])
        if "children" in obj:
            obj["children"] = [Entity.from_cache_typed(e) for e in obj["children"]]
        return NestedEntity(**obj)

    def with_children(self, children):
        """Creates a copy of this entity with the provided children"""
        return self.__class__(
            self._texts, self._spans, self._token_spans, self.entity, children
        )

    @classmethod
    def from_query(
        cls,
        query,
        span=None,
        normalized_span=None,
        entity_type=None,
        role=None,
        entity=None,
        parent_offset=None,
        children=None,
    ):
        """Creates an entity node using a parent entity node

        Args:
            query (Query): The query the entity was found in
            span (Span): The span of the entity in the query's raw text
            normalized_span (Span, optional): The span of the entity in the query's
                normalized text
            entity_type (str, optional): The entity type. Either this or entity must
                be provided
            role (str, optional): The entity role. Ignored if entity is provided.
            entity (Entity, optional): The entity. Either this or entity_type must
                be provided
            parent_offset (int): The offset of the parent in the query
            children (tuple of NestedEntity, optional): Child entities of the created entity

        Returns:
            the created entity
        """

        def _get_token_start(full_norm_text, span_out):
            """
            Calculate the start of a token using the normalized tokens combined as a
            string and delimited by space.

            Args:
                full_norm_text (str): Normalized tokens combined by space as a single string.
                span_out (Span): Span of the token of interest in the full_norm_text.

            Return:
                tok_start (int): The starting token index.
            """
            # The scanned range runs up to span_out.start, or at most to the second to
            # last character
            tok_start = 0
            span_range = min(span_out.start, len(full_norm_text) - 1)
            for idx, current_char in enumerate(full_norm_text[:span_range]):
                # Increment the counter only if a whitespace follows a non-whitespace
                next_char = full_norm_text[idx + 1]
                if not current_char.isspace() and next_char.isspace():
                    tok_start += 1
            return tok_start

        def _get_form_details(query_span, offset, form_in, form_out):
            """
            Get the transformed text, transformed text span, and token index span
            for a given text token. By default, the normalized text form is used when
            calculating the token index span as it accounts for custom tokenization.

            Args:
                query_span (Span): The text span for the token of interest
                offset (int): Offset the final indexing to parent's indexing.
                form_in (int): Integer value representing starting text format
                form_out (int): Integer value representing final text format

            Returns:
                text (str): The transformed version of the input text.
                span_out (Span): Span of the token of interest in the norm_form_full_text.
                tok_span (Span): Span of normalized tokens for the given text.
            """
            span_out = query.transform_span(query_span, form_in, form_out)
            full_text = query.get_text_form(form_out)
            text = span_out.slice(full_text)

            # Calculate Token Span Using the Normalized Text Form
            # Creating a copy to avoid modifying the original query object
            query_norm_form = deepcopy(query)
            norm_form_span_out = query_norm_form.transform_span(
                query_span, form_in, TEXT_FORM_NORMALIZED
            )
            norm_form_full_text = query_norm_form.get_text_form(TEXT_FORM_NORMALIZED)
            norm_form_token_text = norm_form_span_out.slice(norm_form_full_text)
            tok_start = _get_token_start(norm_form_full_text, norm_form_span_out)
            # Using a min token len of 1 avoids a token span of (x, x - 1), which can
            # become negative.
            norm_form_token_text_len = len(norm_form_token_text.split()) or 1
            tok_span = Span(
                start=tok_start, end=tok_start + norm_form_token_text_len - 1
            )

            # convert span from query's indexing to parent's indexing
            if offset is not None:
                # Calculate Token Offset Based on the Query's Normalized Form
                norm_form_offset_out = query_norm_form.transform_index(
                    offset, form_in, TEXT_FORM_NORMALIZED
                )
                tok_offset = len(norm_form_full_text[:norm_form_offset_out].split())
                # Span.shift returns a new Span rather than mutating in place
                tok_span = tok_span.shift(-tok_offset)

                # Calculate the Original Query's Offset Based on New Form
                offset_out = query.transform_index(offset, form_in, form_out)
                span_out = span_out.shift(-offset_out)

            return text, span_out, tok_span

        if span:
            query_span = (
                span.shift(parent_offset) if parent_offset is not None else span
            )
            form_in = TEXT_FORM_RAW
        elif normalized_span:
            query_span = (
                normalized_span.shift(parent_offset)
                if parent_offset is not None
                else normalized_span
            )
            form_in = TEXT_FORM_NORMALIZED
        else:
            raise ValueError("Either 'span' or 'normalized_span' must be specified")

        texts, spans, tok_spans = list(
            zip(
                *[
                    _get_form_details(query_span, parent_offset, form_in, form_out)
                    for form_out in TEXT_FORMS
                ]
            )
        )

        if entity is None:
            if entity_type is None:
                raise ValueError("Either 'entity' or 'entity_type' must be specified")
            entity = Entity(texts[0], entity_type, role=role)

        return cls(texts, spans, tok_spans, entity, children)

    @staticmethod
    def get_largest_non_overlapping_entities(candidates, get_span_func):
        """Filters out overlapping entity spans.

        Args:
            candidates (iterable): An iterable of candidates to filter based on span
            get_span_func (function): A function that accesses the span of each candidate

        Returns:
            list: A list of non-overlapping candidates
        """
        final_spans = Span.get_largest_non_overlapping_candidates(
            [get_span_func(candidate) for candidate in candidates]
        )

        final_candidates = []
        for span in final_spans:
            for candidate in candidates:
                if span == get_span_func(candidate):
                    final_candidates.append(candidate)
                    break
        return final_candidates

    def to_dict(self):
        """Converts the query entity into a dictionary"""
        base = self.entity.to_dict()
        base["span"] = self.span.to_dict()
        if self.children:
            base["children"] = [c.to_dict() for c in self.children]
        return base

    @property
    def text(self):
        """The original input text span"""
        return self._texts[TEXT_FORM_RAW]

    @property
    def processed_text(self):
        """The input text span after it has been preprocessed"""
        return self._texts[TEXT_FORM_PROCESSED]

    @property
    def normalized_text(self):
        """The normalized input text span"""
        return self._texts[TEXT_FORM_NORMALIZED]

    @property
    def span(self):
        """The span of the original input text"""
        return self._spans[TEXT_FORM_RAW]

    @property
    def processed_span(self):
        """The span of the preprocessed text"""
        return self._spans[TEXT_FORM_PROCESSED]

    @property
    def normalized_span(self):
        """The span of the normalized text"""
        return self._spans[TEXT_FORM_NORMALIZED]

    @property
    def token_span(self):
        """The token span of the original input text"""
        return self._token_spans[TEXT_FORM_RAW]

    @property
    def processed_token_span(self):
        """The token span of the preprocessed text"""
        return self._token_spans[TEXT_FORM_PROCESSED]

    @property
    def normalized_token_span(self):
        """The token span of the normalized text"""
        return self._token_spans[TEXT_FORM_NORMALIZED]

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.__dict__ == other.__dict__
        return NotImplemented

    def __ne__(self, other):
        if isinstance(other, self.__class__):
            return not self.__eq__(other)
        return NotImplemented

    def __str__(self):
        return "{}{} '{}' {}-{}".format(
            self.entity.type,
            ":" + self.entity.role if self.entity.role else "",
            self.text,
            self.span.start,
            self.span.end,
        )

    def __repr__(self):
        msg = "<{} {!r} ({!r}) char: [{!r}-{!r}], tok: [{!r}-{!r}]>"
        return msg.format(
            self.__class__.__name__,
            self.text,
            self.entity.type,
            self.span.start,
            self.span.end,
            self.token_span.start,
            self.token_span.end,
        )
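

# An illustrative sketch of from_query, reusing q from the Query sketch above
# ("dish" is a hypothetical entity type). The character span covers "pizza" in the
# raw text, and the token span points at the second token:
#
#   >>> ne = NestedEntity.from_query(q, span=Span(6, 10), entity_type="dish")
#   >>> ne.text, ne.span, ne.token_span
#   ('pizza', Span(start=6, end=10), Span(start=1, end=1))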


class QueryEntity(NestedEntity):
    """An entity with the context of the query it came from.

    Attributes:
        text (str): The raw text that was processed into this entity
        processed_text (str): The processed text that was processed into this entity
        normalized_text (str): The normalized text that was processed into this entity
        span (Span): The character index span of the raw text that was processed into
            this entity
        processed_span (Span): The character index span of the processed text that was
            processed into this entity
        normalized_span (Span): The character index span of the normalized text that was
            processed into this entity
        start (int): The character index start of the text range that was processed into this
            entity. This index is based on the normalized text of the query passed in.
        end (int): The character index end of the text range that was processed into this
            entity. This index is based on the normalized text of the query passed in.
    """


class Entity:
    """An Entity is any important piece of text that provides more information about the user
    intent.

    Attributes:
        text (str): The text contents that span the entity
        type (str): The type of the entity
        role (str): The role of the entity
        value (dict): The resolved value of the entity
        display_text (str): A human readable text representation of the entity for use in
            natural language responses.
        confidence (float): A confidence value from 0 to 1 about how confident the entity
            recognizer was for the given class label.
        is_system_entity (bool): True if the entity is a system entity
    """

    # TODO: look into using __slots__

    def __init__(
        self,
        text,
        entity_type,
        role=None,
        value=None,
        display_text=None,
        confidence=None,
    ):
        self.text = text
        self.type = entity_type
        self.role = role
        self.value = value
        self.display_text = display_text
        self.confidence = confidence
        self.is_system_entity = self.__class__.is_system_entity(entity_type)

    @staticmethod
    def value_to_cache(value):
        result = value
        if value is not None:
            result = value.copy()
            if "children" in value:
                result["children"] = [e.to_cache() for e in value["children"]]
        return result

    def to_cache(self):
        return {
            "class": self.__class__.__name__,
            "text": self.text,
            "entity_type": self.type,
            "role": self.role,
            "value": self.value_to_cache(self.value),
            "display_text": self.display_text,
            "confidence": self.confidence
        }

    @staticmethod
    def from_cache(obj):
        if "children" in obj:
            obj["children"] = [Entity.from_cache_typed(e) for e in obj["children"]]
        return Entity(**obj)

    @staticmethod
    def from_cache_typed(obj):
        """Instantiates a cached entity using the class type that was serialized
        when its to_cache() method was called.
        """
        entity_class = obj.pop("class")
        return Entity.entity_class_map[entity_class].from_cache(obj)

    @staticmethod
    def is_system_entity(entity_type):  # pylint: disable=method-hidden
        """Checks whether the provided entity type is a MindMeld-recognized system entity.

        Args:
            entity_type (str): An entity type

        Returns:
            bool: True if the entity is a system entity type, else False
        """
        return entity_type.startswith(SYSTEM_ENTITY_PREFIX)

    def to_dict(self):
        """Converts the entity into a dictionary"""
        base = {"text": self.text, "type": self.type, "role": self.role}
        for field in ["value", "display_text", "confidence"]:
            value = getattr(self, field)
            if value is not None:
                base[field] = value
        return base

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.__dict__ == other.__dict__
        return NotImplemented

    def __ne__(self, other):
        if isinstance(other, self.__class__):
            return not self.__eq__(other)
        return NotImplemented

    def __repr__(self):
        text = self.display_text or self.text
        return "<{} {!r} ({!r})>".format(self.__class__.__name__, text, self.type)
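

# An illustrative sketch: entity types carrying the system entity prefix (assumed
# to be "sys_" here) are flagged automatically, and None-valued optional fields are
# omitted from the dict form.
#
#   >>> e = Entity("tomorrow", "sys_time", value={"grain": "day"})
#   >>> e.is_system_entity
#   True
#   >>> e.to_dict()
#   {'text': 'tomorrow', 'type': 'sys_time', 'role': None, 'value': {'grain': 'day'}}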


Entity.entity_class_map = {
    Entity.__name__: Entity,
    QueryEntity.__name__: QueryEntity,
    NestedEntity.__name__: NestedEntity
}


class FormEntity:
    """A form entity is used for defining custom objects for the entity form
    used in AutoEntityFilling (slot-filling).

    Attributes:
        entity (str): Entity name
        role (str, optional): The role of the entity
        responses (list/str, optional): Message(s) for prompting the user for missing entities
        retry_response (list/str, optional): Message(s) for re-prompting users. If not
            provided, defaults to responses
        value (dict, optional): The resolved value of the entity
        default_eval (bool, optional): Use system validation (default: True)
        hints (list, optional): Developer defined list of keywords to verify the user
            input against
        custom_eval (str, optional): Name of a custom validation function that returns either
            a bool (validated or not) or a custom resolved value for the entity. If a custom
            resolved value is returned, the slot response is considered valid.
    """

    def __init__(
        self,
        entity: str,
        role: Optional[str] = None,
        responses: Optional[List[str]] = None,
        retry_response: Optional[List[str]] = None,
        value: Optional[Dict] = None,
        default_eval: Optional[bool] = True,
        hints: Optional[List[str]] = None,
        custom_eval: Optional[str] = None,
    ):
        self.entity = entity
        self.role = role

        if isinstance(responses, str):
            responses = [responses]
        self.responses = responses or [
            "Please provide value for: {}".format(self.entity)
        ]

        if isinstance(retry_response, str):
            retry_response = [retry_response]
        self.retry_response = retry_response or self.responses

        self.value = value
        self.default_eval = default_eval
        self.hints = hints
        self.custom_eval = custom_eval

        if not self.entity or not isinstance(self.entity, str):
            raise TypeError("Entity must be a non-empty string.")

        if self.custom_eval and not isinstance(self.custom_eval, str):
            raise TypeError("'custom_eval' function should be a string.")

    def to_dict(self):
        """Converts the entity into a dictionary"""
        base = {}
        for field in self.__dict__:
            val = getattr(self, field)
            if val is not None:
                if isinstance(val, immutables.Map):
                    val = dict(val)
                base[field] = val
        return base
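

# An illustrative sketch ("dish" is a hypothetical entity name): a single string
# response is wrapped in a list, and retry_response falls back to responses when
# not provided.
#
#   >>> fe = FormEntity(entity="dish", responses="Which dish would you like?")
#   >>> fe.responses
#   ['Which dish would you like?']
#   >>> fe.retry_response == fe.responses
#   True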


class CallableRegistry:
    """A registration class to map callable object names to corresponding objects."""

    def __init__(self):
        self._callable_functions_registry = {}

    @property
    def functions_registry(self):
        """Getter for functions registry"""
        return self._callable_functions_registry

    def register(self, func_name, func):
        """Populates the function registry map. A property setter cannot accept a
        (name, function) pair, so registration is exposed as a plain method.

        Args:
            func_name (str): Name to be used as key for the function.
            func (func): Callable function.
        """
        self._callable_functions_registry[func_name] = func
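

# An illustrative sketch of the registry, using the register method defined above
# (the function name here is hypothetical):
#
#   >>> registry = CallableRegistry()
#   >>> registry.register("validate_dish", lambda value: bool(value))
#   >>> "validate_dish" in registry.functions_registry
#   True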


def resolve_entity_conflicts(query_entities):
    """This method takes a list containing query entities for a query, and resolves
    any entity conflicts. The resolved list is returned.

    If two entities in a query conflict with each other, use the following logic:
        - If the target entity is a subset of another entity, then delete the
          target entity.
        - If the target entity shares the identical span as another entity,
          then keep the one with the highest confidence.
        - If the target entity overlaps with another entity, then keep the one
          with the highest confidence.

    Args:
        query_entities (list of QueryEntity): A list of query entities to resolve

    Returns:
        list of QueryEntity: A filtered list of query entities
    """
    filtered = query_entities
    i = 0
    while i < len(filtered):
        include_target = True
        target = filtered[i]
        j = i + 1
        while j < len(filtered):
            other = filtered[j]
            if _is_superset(target, other) and not _is_same_span(target, other):
                logger.debug(
                    "Removing {{%s|%s}} entity in query %d since it is a "
                    "subset of another.",
                    other.text,
                    other.entity.type,
                    i,
                )
                del filtered[j]
                continue
            if _is_subset(target, other) and not _is_same_span(target, other):
                logger.debug(
                    "Removing {{%s|%s}} entity in query %d since it is a "
                    "subset of another.",
                    target.text,
                    target.entity.type,
                    i,
                )
                del filtered[i]
                include_target = False
                break
            if _is_same_span(target, other) or _is_overlapping(target, other):
                if target.entity.confidence >= other.entity.confidence:
                    logger.debug(
                        "Removing {{%s|%s}} entity in query %d since it overlaps "
                        "with another.",
                        other.text,
                        other.entity.type,
                        i,
                    )
                    del filtered[j]
                    continue
                if target.entity.confidence < other.entity.confidence:
                    logger.debug(
                        "Removing {{%s|%s}} entity in query %d since it overlaps "
                        "with another.",
                        target.text,
                        target.entity.type,
                        i,
                    )
                    del filtered[i]
                    include_target = False
                    break
            j += 1
        if include_target:
            i += 1

    return filtered


def _is_subset(target, other):
    return (target.start >= other.start) and (target.end <= other.end)


def _is_superset(target, other):
    return (target.start <= other.start) and (target.end >= other.end)


def _is_same_span(target, other):
    return _is_superset(target, other) and _is_subset(target, other)


def _is_overlapping(target, other):
    overlap = _get_overlap(target, other)
    return overlap and not _is_subset(target, other) and not _is_superset(target, other)


def _get_overlap(target, other):
    target_range = range(target.start, target.end + 1)
    predicted_range = range(other.start, other.end + 1)
    return set(target_range).intersection(predicted_range)
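

# An illustrative sketch of the span predicates; Span objects stand in for any
# object with start and end attributes. Note that containment does not count as
# "overlapping":
#
#   >>> _is_subset(Span(2, 4), Span(0, 8))
#   True
#   >>> _is_overlapping(Span(0, 4), Span(3, 8))
#   True
#   >>> _is_overlapping(Span(2, 4), Span(0, 8))
#   False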