# -*- coding: utf-8 -*-
#
# Copyright (c) 2015 Cisco Systems, Inc. and others. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module contains base classes for models defined in the models subpackage."""
import logging
from collections import namedtuple
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
from sklearn.metrics import precision_recall_fscore_support as score
from .helpers import (
ENTITIES_LABEL_TYPE,
entity_seqs_equal,
get_label_encoder,
)
from .taggers.taggers import (
BoundaryCounts,
get_boundary_counts,
)
logger = logging.getLogger(__name__)


class EvaluatedExample(
namedtuple(
"EvaluatedExample", ["example", "expected", "predicted", "probas", "label_type"]
)
):
"""Represents the evaluation of a single example
Attributes:
example: The example being evaluated
expected: The expected label for the example
predicted: The predicted label for the example
proba (dict): Maps labels to their predicted probabilities
label_type (str): One of CLASS_LABEL_TYPE or ENTITIES_LABEL_TYPE
"""
    @property
    def is_correct(self):
        """bool: Whether the predicted label matches the expected label."""
        # For entities, compare just the type, span, and text of each entity.
if self.label_type == ENTITIES_LABEL_TYPE:
return entity_seqs_equal(self.expected, self.predicted)
# For other label_types compare the full objects
else:
return self.expected == self.predicted
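
# Example usage (an illustrative sketch: the "greet"/"exit" labels and the
# probability values are hypothetical, and CLASS_LABEL_TYPE is assumed to be
# importable from the same helpers module as ENTITIES_LABEL_TYPE):
#
#   evaluated = EvaluatedExample(
#       example="hi there",
#       expected="greet",
#       predicted="greet",
#       probas={"greet": 0.92, "exit": 0.08},
#       label_type=CLASS_LABEL_TYPE,
#   )
#   assert evaluated.is_correct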


class RawResults:
"""Represents the raw results of a set of evaluated examples. Useful for generating
stats and graphs.
Attributes:
predicted (list): A list of predictions. For sequences this is a list of lists, and for
standard classifieris this is a 1d array. All classes are in their numeric
representations for ease of use with evaluation libraries and graphing.
expected (list): Same as predicted but contains the true or gold values.
text_labels (list): A list of all the text label values, the index of the text label in
this array is the numeric label
predicted_flat (list): (Optional): For sequence models this is a flattened list of all
predicted tags (1d array)
expected_flat (list): (Optional): For sequence models this is a flattened list of all gold
tags
"""
def __init__(
self, predicted, expected, text_labels, predicted_flat=None, expected_flat=None
):
self.predicted = predicted
self.expected = expected
self.text_labels = text_labels
self.predicted_flat = predicted_flat
self.expected_flat = expected_flat


class ModelEvaluation(namedtuple("ModelEvaluation", ["config", "results"])):
"""Represents the evaluation of a model at a specific configuration
using a collection of examples and labels.
Attributes:
config (ModelConfig): The model config used during evaluation.
results (list of EvaluatedExample): A list of the evaluated examples.
"""
    def __init__(self, config, results):
        # `results` is already stored on the tuple by the namedtuple's __new__;
        # the local reference is unused here, so delete it to appease linters.
        del results
        self.label_encoder = get_label_encoder(config)

    def get_accuracy(self):
"""The accuracy represents the share of examples whose predicted labels
exactly matched their expected labels.
Returns:
float: The accuracy of the model.
"""
num_examples = len(self.results)
num_correct = len([e for e in self.results if e.is_correct])
return float(num_correct) / float(num_examples)
def __repr__(self):
num_examples = len(self.results)
num_correct = len(list(self.correct_results()))
accuracy = self.get_accuracy()
msg = "<{} score: {:.2%}, {} of {} example{} correct>"
return msg.format(
self.__class__.__name__,
accuracy,
num_correct,
num_examples,
"" if num_examples == 1 else "s",
)

    def correct_results(self):
"""
Returns:
iterable: Collection of the examples which were correct
"""
for result in self.results:
if result.is_correct:
yield result

    def incorrect_results(self):
"""
Returns:
iterable: Collection of the examples which were incorrect
"""
for result in self.results:
if not result.is_correct:
yield result

    def get_stats(self):
"""
Returns a structured stats object for evaluation.
Returns:
dict: Structured dict containing evaluation statistics. Contains precision, \
recall, f scores, support, etc.
"""
raise NotImplementedError

    def print_stats(self):
        """
        Prints a useful stats table for evaluation to stdout.

        Returns:
            None
        """
raise NotImplementedError

    def raw_results(self):
        """
        Exposes raw vectors of expected and predicted values for data scientists to use
        for any additional evaluation metrics or to generate graphs of their choice.

        Returns:
            RawResults: A named tuple containing:

                * expected: vector of gold classes (numeric values)
                * predicted: vector of predicted classes (numeric values)
                * text_labels: a list of all the text label values; the index of a
                  text label in this list is its numeric label
        """
raise NotImplementedError
@staticmethod
def _update_raw_result(label, text_labels, vec):
"""
Helper method for updating the text to numeric label vectors
Returns:
(tuple): tuple containing:
* text_labels: The updated text_labels array
* vec: The updated label vector with the given label appended
"""
if label not in text_labels:
text_labels.append(label)
vec.append(text_labels.index(label))
return text_labels, vec
    def _get_common_stats(self, raw_expected, raw_predicted, text_labels):
        """
        Computes overall and per-class statistics for the evaluation.

        Returns:
            dict: Structured dict containing evaluation statistics. Contains precision, \
                recall, f scores, support, etc.
        """
labels = range(len(text_labels))
confusion_stats = self._get_confusion_matrix_and_counts(
y_true=raw_expected, y_pred=raw_predicted
)
stats_overall = self._get_overall_stats(
y_true=raw_expected, y_pred=raw_predicted, labels=labels
)
counts_overall = confusion_stats["counts_overall"]
stats_overall["tp"] = counts_overall.tp
stats_overall["tn"] = counts_overall.tn
stats_overall["fp"] = counts_overall.fp
stats_overall["fn"] = counts_overall.fn
class_stats = self._get_class_stats(
y_true=raw_expected, y_pred=raw_predicted, labels=labels
)
counts_by_class = confusion_stats["counts_by_class"]
class_stats["tp"] = counts_by_class.tp
class_stats["tn"] = counts_by_class.tn
class_stats["fp"] = counts_by_class.fp
class_stats["fn"] = counts_by_class.fn
return {
"stats_overall": stats_overall,
"class_labels": text_labels,
"class_stats": class_stats,
"confusion_matrix": confusion_stats["confusion_matrix"],
}
@staticmethod
def _get_class_stats(y_true, y_pred, labels):
"""
Method for getting some basic statistics by class.
Returns:
dict: A structured dictionary containing precision, recall, f_beta, and support \
vectors (1 x number of classes)
"""
precision, recall, f_beta, support = score(
y_true=y_true, y_pred=y_pred, labels=labels
)
stats = {
"precision": precision,
"recall": recall,
"f_beta": f_beta,
"support": support,
}
return stats
@staticmethod
def _get_overall_stats(y_true, y_pred, labels):
"""
Method for getting some overall statistics.
Returns:
dict: A structured dictionary containing scalar values for f1 scores and overall \
accuracy.
"""
f1_weighted = f1_score(
y_true=y_true, y_pred=y_pred, labels=labels, average="weighted"
)
f1_macro = f1_score(
y_true=y_true, y_pred=y_pred, labels=labels, average="macro"
)
f1_micro = f1_score(
y_true=y_true, y_pred=y_pred, labels=labels, average="micro"
)
accuracy = accuracy_score(y_true=y_true, y_pred=y_pred)
stats_overall = {
"f1_weighted": f1_weighted,
"f1_macro": f1_macro,
"f1_micro": f1_micro,
"accuracy": accuracy,
}
return stats_overall
@staticmethod
def _get_confusion_matrix_and_counts(y_true, y_pred):
"""
Generates the confusion matrix where each element Cij is the number of observations known to
be in group i predicted to be in group j
Returns:
dict: Contains 2d array of the confusion matrix, and an array of tp, tn, fp, fn values
"""
confusion_mat = confusion_matrix(y_true=y_true, y_pred=y_pred)
tp_arr, tn_arr, fp_arr, fn_arr = [], [], [], []
num_classes = len(confusion_mat)
for class_index in range(num_classes):
            # tp is C[class_index][class_index]
tp = confusion_mat[class_index][class_index]
tp_arr.append(tp)
# tn is the sum of Cij where i or j are not class_index
mask = np.ones((num_classes, num_classes))
mask[:, class_index] = 0
mask[class_index, :] = 0
tn = np.sum(mask * confusion_mat)
tn_arr.append(tn)
# fp is the sum of Cij where j is class_index but i is not
mask = np.zeros((num_classes, num_classes))
mask[:, class_index] = 1
mask[class_index, class_index] = 0
fp = np.sum(mask * confusion_mat)
fp_arr.append(fp)
# fn is the sum of Cij where i is class_index but j is not
mask = np.zeros((num_classes, num_classes))
mask[class_index, :] = 1
mask[class_index, class_index] = 0
fn = np.sum(mask * confusion_mat)
fn_arr.append(fn)
Counts = namedtuple("Counts", ["tp", "tn", "fp", "fn"])
return {
"confusion_matrix": confusion_mat,
"counts_by_class": Counts(tp_arr, tn_arr, fp_arr, fn_arr),
"counts_overall": Counts(
sum(tp_arr), sum(tn_arr), sum(fp_arr), sum(fn_arr)
),
}
def _print_class_stats_table(self, stats, text_labels, title="Statistics by class"):
"""
Helper for printing a human readable table for class statistics
Returns:
None
"""
title_format = "{:>20}" + "{:>12}" * (len(stats))
common_stats = [
"f_beta",
"precision",
"recall",
"support",
"tp",
"tn",
"fp",
"fn",
]
stat_row_format = (
"{:>20}"
+ "{:>12.3f}" * 3
+ "{:>12.0f}" * 5
+ "{:>12.3f}" * (len(stats) - len(common_stats))
)
table_titles = common_stats + [
stat for stat in stats.keys() if stat not in common_stats
]
print(title + ": \n")
print(title_format.format("class", *table_titles))
for label_index, label in enumerate(text_labels):
row = []
for stat in table_titles:
row.append(stats[stat][label_index])
print(stat_row_format.format(self._truncate_label(label, 18), *row))
print("\n\n")
def _print_class_matrix(self, matrix, text_labels):
"""
Helper for printing a human readable class by class table for displaying
a confusion matrix
Returns:
None
"""
# Doesn't print if there isn't enough space to display the full matrix.
if len(text_labels) > 10:
print(
"Not printing confusion matrix since it is too large. The full matrix is still"
" included in the dictionary returned from get_stats()."
)
return
labels = range(len(text_labels))
title_format = "{:>15}" * (len(labels) + 1)
stat_row_format = "{:>15}" * (len(labels) + 1)
table_titles = [
self._truncate_label(text_labels[label], 10) for label in labels
]
print("Confusion matrix: \n")
print(title_format.format("", *table_titles))
for label_index, label in enumerate(text_labels):
print(
stat_row_format.format(
self._truncate_label(label, 10), *matrix[label_index]
)
)
print("\n\n")
@staticmethod
def _print_overall_stats_table(stats_overall, title="Overall statistics"):
"""
Helper for printing a human readable table for overall statistics
Returns:
None
"""
title_format = "{:>12}" * (len(stats_overall))
common_stats = ["accuracy", "f1_weighted", "tp", "tn", "fp", "fn"]
stat_row_format = (
"{:>12.3f}" * 2
+ "{:>12.0f}" * 4
+ "{:>12.3f}" * (len(stats_overall) - len(common_stats))
)
table_titles = common_stats + [
stat for stat in stats_overall.keys() if stat not in common_stats
]
print(title + ": \n")
print(title_format.format(*table_titles))
row = []
for stat in table_titles:
row.append(stats_overall[stat])
print(stat_row_format.format(*row))
print("\n\n")
@staticmethod
def _truncate_label(label, max_len):
return (label[:max_len] + "..") if len(label) > max_len else label


class StandardModelEvaluation(ModelEvaluation):
    def raw_results(self):
"""Returns the raw results of the model evaluation"""
text_labels = []
predicted, expected = [], []
for result in self.results:
text_labels, predicted = self._update_raw_result(
result.predicted, text_labels, predicted
)
text_labels, expected = self._update_raw_result(
result.expected, text_labels, expected
)
return RawResults(
predicted=predicted, expected=expected, text_labels=text_labels
)

    def get_stats(self):
        """Returns a structured stats dict for the model evaluation"""
raw_results = self.raw_results()
stats = self._get_common_stats(
raw_results.expected, raw_results.predicted, raw_results.text_labels
)
        # Note: stats specific to the standard model can be added to any of the tables here
return stats

    def print_stats(self):
"""Prints model evaluation stats to stdout"""
raw_results = self.raw_results()
stats = self.get_stats()
self._print_overall_stats_table(stats["stats_overall"])
self._print_class_stats_table(stats["class_stats"], raw_results.text_labels)
self._print_class_matrix(stats["confusion_matrix"], raw_results.text_labels)
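
# Typical usage of StandardModelEvaluation (a sketch; `model_config` and
# `results` are assumed to be a ModelConfig and a list of EvaluatedExample
# instances built by the surrounding training code):
#
#   evaluation = StandardModelEvaluation(model_config, results)
#   stats = evaluation.get_stats()  # structured dict for programmatic use
#   evaluation.print_stats()        # human-readable tables on stdout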


class SequenceModelEvaluation(ModelEvaluation):
def __init__(self, config, results):
self._tag_scheme = config.model_settings.get("tag_scheme", "IOB").upper()
super().__init__(config, results)

    def raw_results(self):
"""Returns the raw results of the model evaluation"""
text_labels = []
predicted, expected = [], []
predicted_flat, expected_flat = [], []
for result in self.results:
raw_predicted = self.label_encoder.encode(
[result.predicted], examples=[result.example]
)[0]
raw_expected = self.label_encoder.encode(
[result.expected], examples=[result.example]
)[0]
vec = []
for entity in raw_predicted:
text_labels, vec = self._update_raw_result(entity, text_labels, vec)
predicted.append(vec)
predicted_flat.extend(vec)
vec = []
for entity in raw_expected:
text_labels, vec = self._update_raw_result(entity, text_labels, vec)
expected.append(vec)
expected_flat.extend(vec)
return RawResults(
predicted=predicted,
expected=expected,
text_labels=text_labels,
predicted_flat=predicted_flat,
expected_flat=expected_flat,
)
    def _get_sequence_stats(self):
        """
        Returns sequence-level stats for the evaluation.

        TODO: Generate additional sequence level stats
        """
sequence_accuracy = self.get_accuracy()
return {"sequence_accuracy": sequence_accuracy}
@staticmethod
def _print_sequence_stats_table(sequence_stats):
"""
Helper for printing a human readable table for sequence statistics
Returns:
None
"""
title_format = "{:>18}" * (len(sequence_stats))
table_titles = ["sequence_accuracy"]
stat_row_format = "{:>18.3f}" * (len(sequence_stats))
print("Sequence-level statistics: \n")
print(title_format.format(*table_titles))
row = []
for stat in table_titles:
row.append(sequence_stats[stat])
print(stat_row_format.format(*row))
print("\n\n")

    def get_stats(self):
        """Returns a structured stats dict for the model evaluation"""
raw_results = self.raw_results()
stats = self._get_common_stats(
raw_results.expected_flat,
raw_results.predicted_flat,
raw_results.text_labels,
)
sequence_stats = self._get_sequence_stats()
stats["sequence_stats"] = sequence_stats
        # Note: stats specific to the sequence model can be added to any of the tables here
return stats

    def print_stats(self):
"""Prints model evaluation stats to stdout"""
raw_results = self.raw_results()
stats = self.get_stats()
self._print_overall_stats_table(
stats["stats_overall"], "Overall tag-level statistics"
)
self._print_class_stats_table(
stats["class_stats"],
raw_results.text_labels,
"Tag-level statistics by class",
)
self._print_class_matrix(stats["confusion_matrix"], raw_results.text_labels)
self._print_sequence_stats_table(stats["sequence_stats"])


class EntityModelEvaluation(SequenceModelEvaluation):
"""Generates some statistics specific to entity recognition"""
    def _get_entity_boundary_stats(self):
        """
        Calculates label errors (le), boundary errors (be), label-boundary errors (lbe),
        and tp, tn, fp, fn counts, as defined here:
        https://nlpers.blogspot.com/2006/08/doing-named-entity-recognition-dont.html
        """
boundary_counts = BoundaryCounts()
raw_results = self.raw_results()
for expected_sequence, predicted_sequence in zip(
raw_results.expected, raw_results.predicted
):
expected_seq_labels = [
raw_results.text_labels[i] for i in expected_sequence
]
predicted_seq_labels = [
raw_results.text_labels[i] for i in predicted_sequence
]
boundary_counts = get_boundary_counts(
expected_seq_labels, predicted_seq_labels, boundary_counts
)
return boundary_counts.to_dict()
@staticmethod
def _print_boundary_stats(boundary_counts):
title_format = "{:>12}" * (len(boundary_counts))
table_titles = boundary_counts.keys()
stat_row_format = "{:>12}" * (len(boundary_counts))
print("Segment-level statistics: \n")
print(title_format.format(*table_titles))
row = []
for stat in table_titles:
row.append(boundary_counts[stat])
print(stat_row_format.format(*row))
print("\n\n")

    def get_stats(self):
        """Returns a structured stats dict, adding boundary stats when the tag scheme is IOB"""
stats = super().get_stats()
if self._tag_scheme == "IOB":
boundary_stats = self._get_entity_boundary_stats()
stats["boundary_stats"] = boundary_stats
return stats

    def print_stats(self):
        """Prints model evaluation stats, including boundary stats when applicable, to stdout"""
raw_results = self.raw_results()
stats = self.get_stats()
self._print_overall_stats_table(
stats["stats_overall"], "Overall tag-level statistics"
)
self._print_class_stats_table(
stats["class_stats"],
raw_results.text_labels,
"Tag-level statistics by class",
)
self._print_class_matrix(stats["confusion_matrix"], raw_results.text_labels)
if self._tag_scheme == "IOB":
self._print_boundary_stats(stats["boundary_stats"])
self._print_sequence_stats_table(stats["sequence_stats"])