# -*- coding: utf-8 -*-
#
# Copyright (c) 2015 Cisco Systems, Inc. and others. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Custom modules built on top of nn layers that can do token classification
"""
import logging
from abc import abstractmethod
from collections import OrderedDict
from typing import List
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torchcrf import CRF
from .classification import BaseClassification
from .helpers import TokenClassificationType
from .helpers import TokenizerType, EmbedderType
from .layers import (
EmbeddingLayer,
CnnLayer,
LstmLayer,
PoolingLayer,
SplittingAndPoolingLayer
)
from ..containers import HuggingfaceTransformersContainer
logger = logging.getLogger(__name__)
class BaseTokenClassification(BaseClassification):
"""Base class that defines all the necessary elements to successfully train/infer
custom pytorch modules wrapped on top of this base class. Classes derived from
this base can be trained for sequence tagging aka. token classification.
"""
@property
def classification_type(self):
return "tagger"
def _prepare_labels(self, labels: List[List[int]], max_length: int):
# for token classification, the length of an example matters (i.e. the number of
# (sub-)words in it) as the labels/targets are one-to-one mapped per (sub-)word.
# hence, label sequences are padded (or trimmed) to max_length using _label_padding_idx
def _trim_or_pad(label_sequence):
if len(label_sequence) > max_length:
return label_sequence[:max_length]
else:
return (
label_sequence + [self.params._label_padding_idx] *
(max_length - len(label_sequence))
)
return torch.as_tensor(
[_trim_or_pad(_label_sequence) for _label_sequence in labels],
dtype=torch.long
)
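# Illustrative example (hypothetical values, assuming _label_padding_idx == -1 and
# max_length == 4): an input of [[3, 1], [2, 2, 0, 4, 5]] would be padded/trimmed to
#   tensor([[ 3,  1, -1, -1],
#           [ 2,  2,  0,  4]])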
def _init_graph(self):
# init the underlying params and architectural components
self._init_core()
try:
assert self.out_dim > 0
self.params.update({"out_dim": self.out_dim})
except (AttributeError, AssertionError) as e:
msg = f"Derived class '{self.name}' must indicate its hidden size for dense layer " \
f"classification by having an attribute 'self.out_dim', which must be a " \
f"positive integer greater than 1"
raise ValueError(msg) from e
# init the peripheral architecture params and architectural components
self.span_pooling_layer = SplittingAndPoolingLayer(
self.params.token_spans_pooling_type,
self.encoder.number_of_terminal_tokens, # arg: maximum number of terminal tokens
)
if not self.params.num_labels:
msg = f"Invalid number of labels ({self.params.num_labels}) inputted to '{self.name}'"
raise ValueError(msg)
# init the peripheral architectural components and the criterion to compute loss
self.dense_layer_dropout = nn.Dropout(
p=1 - self.params.output_keep_prob
)
if self.params.use_crf_layer:
self.classifier_head = nn.Linear(self.out_dim, self.params.num_labels)
self.crf_layer = CRF(self.params.num_labels, batch_first=True)
else:
if self.params.num_labels >= 2:
# cross-entropy criterion
self.classifier_head = nn.Linear(self.out_dim, self.params.num_labels)
self.criterion = nn.CrossEntropyLoss(
reduction='mean', ignore_index=self.params._label_padding_idx)
else:
msg = f"Invalid number of labels specified: {self.params.num_labels}. " \
f"A valid number is equal to or greater than 2"
raise ValueError(msg)
msg = f"{self.name} is initialized"
logger.info(msg)
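# e.g. (hypothetical sizes): with out_dim == 128 and num_labels == 5, classifier_head is
# Linear(128, 5); with use_crf_layer the per-token scores feed a CRF that also models label
# transitions, otherwise a token-level cross-entropy loss (ignoring _label_padding_idx
# positions) is used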
def forward(self, batch_data):
batch_data = self.to_device(batch_data)
batch_data = self._forward_core(batch_data)
token_embs = batch_data["token_embs"]
# strip the terminals if required; note that they are not accounted for in the label padding process
token_embs = self.span_pooling_layer(
token_embs,
batch_data["split_lengths"], # List[Tensor1d[int]],
discard_terminals=self.params.add_terminals
) # [BS, SEQ_LEN', EMB_DIM]; SEQ_LEN' matches the width of the labels after padding
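# illustration (hypothetical): if an example's split_lengths is Tensor([1, 2, 1]) because its
# middle word was broken into two sub-words, the 4 sub-word embeddings (plus terminals, which
# are discarded here when add_terminals is set) are pooled into 3 word-level embeddings,
# one per label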
token_embs = self.dense_layer_dropout(token_embs)
logits = self.classifier_head(token_embs)
batch_data.update({"logits": logits})
targets = batch_data.pop("_labels", None)
if targets is not None:
if self.params.use_crf_layer:
# create a mask to ignore token positions that are padded
mask = torch.as_tensor(targets != self.params._label_padding_idx,
dtype=torch.uint8).to(self.params.device)
loss = - self.crf_layer(logits, targets, mask=mask) # negative log likelihood
else:
loss = self.criterion(logits.view(-1, logits.shape[-1]), targets.view(-1))
batch_data.update({"loss": loss})
return batch_data
def predict(self, examples):
predictions = []
was_training = self.training
self.eval()
with torch.no_grad():
for start_idx in range(0, len(examples), self.params.batch_size):
batch_examples = examples[start_idx:start_idx + self.params.batch_size]
batch_data = self.encoder.batch_encode(
batch_examples,
padding_length=self.params.padding_length,
**({'add_terminals': self.params.add_terminals}
if self.params.add_terminals is not None else {})
)
batch_logits = self.forward(batch_data)["logits"]
# find predictions
if self.params.use_crf_layer:
# create a mask to ignore token positions that are padded
mask = pad_sequence(
[torch.as_tensor([1] * len(_split_lengths))
for _split_lengths in batch_data["split_lengths"]],
batch_first=True
)
# list of [SEQ_LEN_i] tensors -> [BS, SEQ_LEN]
mask = torch.as_tensor(mask, dtype=torch.uint8).to(self.params.device)
batch_predictions = self.crf_layer.decode(batch_logits, mask=mask)
else:
batch_predictions = torch.argmax(batch_logits, dim=-1).tolist()
# trim predictions as per sequence length
batch_predictions = [
predictions[:len(_split_lengths)] for predictions, _split_lengths in
zip(batch_predictions, batch_data["split_lengths"])
]
predictions.extend(batch_predictions)
if was_training:
self.train()
return predictions
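# example usage (hypothetical), assuming `model` is a trained subclass instance:
#   model.predict(["book a table for two"])
#   -> [[4, 0, 0, 0, 2]]   # one predicted label id per word in the example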
def predict_proba(self, examples):
prediction_tuples = []
was_training = self.training
self.eval()
with torch.no_grad():
for start_idx in range(0, len(examples), self.params.batch_size):
batch_examples = examples[start_idx:start_idx + self.params.batch_size]
batch_data = self.encoder.batch_encode(
batch_examples,
padding_length=self.params.padding_length,
**({'add_terminals': self.params.add_terminals}
if self.params.add_terminals is not None else {})
)
batch_logits = self.forward(batch_data)["logits"]
# find prediction probabilities
if self.params.use_crf_layer:
msg = f"Prediction probabilities cannot be computed when the param " \
f"use_crf_layer is set to True in {self.__class__.__name__}"
raise NotImplementedError(msg)
else:
batch_maxes = torch.max(batch_logits, dim=-1)
batch_predictions = batch_maxes.indices.tolist()
batch_probabilities = batch_maxes.values.tolist()
for preds, probs, _split_lengths in zip(
batch_predictions, batch_probabilities, batch_data["split_lengths"]
):
prediction_tuples.append(list(zip(preds, probs))[:len(_split_lengths)])
if was_training:
self.train()
return prediction_tuples
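# note: the scores returned here are the maximum (un-normalized) logit values from the
# classification head, not softmax probabilities. Example output (hypothetical):
#   model.predict_proba(["book a table"])
#   -> [[(4, 7.1), (0, 9.3), (2, 6.8)]]   # one (label id, score) tuple per word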
@abstractmethod
def _init_core(self) -> None:
raise NotImplementedError
@abstractmethod
def _forward_core(self, batch_data):
raise NotImplementedError
class EmbedderForTokenClassification(BaseTokenClassification):
def _init_core(self):
self.emb_layer = EmbeddingLayer(
self.params._num_tokens,
self.params.emb_dim,
self.params._padding_idx,
self.params.pop("_embedding_weights", None),
self.params.update_embeddings,
1 - self.params.embedder_output_keep_prob
)
self.out_dim = self.params.emb_dim
def _forward_core(self, batch_data):
seq_ids = batch_data["seq_ids"] # [BS, SEQ_LEN]
token_embs = self.emb_layer(seq_ids) # [BS, SEQ_LEN, self.out_dim]
batch_data.update({"token_embs": token_embs})
return batch_data
class LstmForTokenClassification(BaseTokenClassification):
"""A LSTM module that operates on a batched sequence of token ids. The tokens could be
characters or words or sub-words. This module uses an additional input that determines
how the sequence of embeddings obtained after the LSTM layers for each instance in the
batch, needs to be split. Once split, the sub-groups of embeddings (each sub-group
corresponding to a word or a phrase) can be collapsed to 1D representation per sub-group
through pooling operations. Finally, this module outputs a 2D representation for each
instance in the batch (i.e. [BS, SEQ_LEN', EMB_DIM]).
"""
def _init_core(self):
self.emb_layer = EmbeddingLayer(
self.params._num_tokens,
self.params.emb_dim,
self.params._padding_idx,
self.params.pop("_embedding_weights", None),
self.params.update_embeddings,
1 - self.params.embedder_output_keep_prob
)
self.lstm_layer = LstmLayer(
self.params.emb_dim,
self.params.lstm_hidden_dim,
self.params.lstm_num_layers,
1 - self.params.lstm_keep_prob,
self.params.lstm_bidirectional
)
self.out_dim = (
self.params.lstm_hidden_dim * 2 if self.params.lstm_bidirectional
else self.params.lstm_hidden_dim
)
def _forward_core(self, batch_data):
seq_ids = batch_data["seq_ids"] # [BS, SEQ_LEN]
summed_split_lengths = [
sum(_split_lengths) +
(self.encoder.number_of_terminal_tokens if self.params.add_terminals else 0)
for _split_lengths in batch_data["split_lengths"]
]
summed_split_lengths = torch.as_tensor(summed_split_lengths, dtype=torch.long) # [BS]
token_embs = self.emb_layer(seq_ids) # [BS, SEQ_LEN, EMB_DIM]
token_embs = self.lstm_layer(
token_embs, summed_split_lengths) # [BS, SEQ_LEN, self.out_dim]
batch_data.update({"token_embs": token_embs})
return batch_data
class CharLstmWithWordLstmForTokenClassification(BaseTokenClassification):
def _prepare_input_encoder(self, examples, **params):
if (params.get("tokenizer_type") and
TokenizerType(params.get("tokenizer_type")) !=
TokenizerType.WHITESPACE_AND_CHAR_DUAL_TOKENIZER
):
msg = f"Param 'tokenizer_type' must be " \
f"{TokenizerType.WHITESPACE_AND_CHAR_DUAL_TOKENIZER.value} for " \
f"{self.__class__.__name__}."
raise ValueError(msg)
params = super()._prepare_input_encoder(examples, **params)
# update params without which loading a saved model can become ambiguous
params.update({
"_char_num_tokens": len(self.encoder.get_char_vocab()),
"_char_padding_idx": self.encoder.get_char_pad_token_idx(),
"char_add_terminals": params.get("char_add_terminals"),
"char_padding_length": params.get("char_padding_length"),
"char_emb_dim": params.get("char_emb_dim"),
})
return params
def _init_core(self):
self.char_proj_dim = self.params.get("char_proj_dim") or self.params.emb_dim
self.params.update({"char_proj_dim": self.char_proj_dim})
self.char_emb_layer = EmbeddingLayer(
self.params._char_num_tokens,
self.params.char_emb_dim,
self.params._char_padding_idx,
self.params.pop("_char_embedding_weights", None),
self.params.update_embeddings
)
self.char_lstm_layer = LstmLayer(
self.params.char_emb_dim,
self.params.char_lstm_hidden_dim,
self.params.char_lstm_num_layers,
1 - self.params.char_lstm_keep_prob,
self.params.char_lstm_bidirectional
)
self.char_dropout = nn.Dropout(
p=1 - self.params.char_lstm_keep_prob
)
self.char_lstm_output_pooling = PoolingLayer(
self.params.char_lstm_output_pooling_type
)
char_out_dim = (
self.params.char_lstm_hidden_dim * 2 if self.params.char_lstm_bidirectional
else self.params.char_lstm_hidden_dim
)
self.char_lstm_output_transform = nn.Linear(char_out_dim, self.params.char_proj_dim)
self.emb_layer = EmbeddingLayer(
self.params._num_tokens,
self.params.emb_dim,
self.params._padding_idx,
self.params.pop("_embedding_weights", None),
self.params.update_embeddings,
1 - self.params.embedder_output_keep_prob
)
self.lstm_layer = LstmLayer(
self.params.emb_dim + self.params.char_proj_dim,
self.params.lstm_hidden_dim,
self.params.lstm_num_layers,
1 - self.params.lstm_keep_prob,
self.params.lstm_bidirectional
)
self.out_dim = (
self.params.lstm_hidden_dim * 2 if self.params.lstm_bidirectional
else self.params.lstm_hidden_dim
)
def _forward_core(self, batch_data):
char_seq_ids = batch_data["char_seq_ids"] # List of [BS, SEQ_LEN]
char_seq_lengths = batch_data["char_seq_lengths"] # List of [BS]
encs = [self.char_emb_layer(_seq_ids) for _seq_ids in char_seq_ids]
encs = [self.char_lstm_layer(enc, seq_len) for enc, seq_len in zip(encs, char_seq_lengths)]
encs = [self.char_dropout(enc) for enc in encs]
encs = [self.char_lstm_output_pooling(enc, seq_len)
for enc, seq_len in zip(encs, char_seq_lengths)]
encs = pad_sequence(encs, batch_first=True) # [BS, SEQ_LEN, char_out_dim]
char_encs = self.char_lstm_output_transform(encs) # [BS, SEQ_LEN, self.char_proj_dim]
summed_split_lengths = [
sum(_split_lengths) +
(self.encoder.number_of_terminal_tokens if self.params.add_terminals else 0)
for _split_lengths in batch_data["split_lengths"]
]
summed_split_lengths = torch.as_tensor(summed_split_lengths, dtype=torch.long) # [BS]
seq_ids = batch_data["seq_ids"] # [BS, SEQ_LEN]
word_encs = self.emb_layer(seq_ids) # [BS, SEQ_LEN, self.emb_dim]
char_plus_word_encs = torch.cat((char_encs, word_encs), dim=-1) # [BS, SEQ_LEN, sum(both)]
token_embs = self.lstm_layer(
char_plus_word_encs, summed_split_lengths) # [BS, SEQ_LEN, self.out_dim]
batch_data.update({"token_embs": token_embs})
return batch_data
class CharCnnWithWordLstmForTokenClassification(BaseTokenClassification):
def _prepare_input_encoder(self, examples, **params):
if (params.get("tokenizer_type") and
TokenizerType(params.get("tokenizer_type")) !=
TokenizerType.WHITESPACE_AND_CHAR_DUAL_TOKENIZER
):
msg = f"Param 'tokenizer_type' must be " \
f"{TokenizerType.WHITESPACE_AND_CHAR_DUAL_TOKENIZER.value} for " \
f"{self.__class__.__name__}."
raise ValueError(msg)
params = super()._prepare_input_encoder(examples, **params)
# update params without which loading a saved model can become ambiguous
params.update({
"_char_num_tokens": len(self.encoder.get_char_vocab()),
"_char_padding_idx": self.encoder.get_char_pad_token_idx(),
"char_add_terminals": params.get("char_add_terminals"),
"char_padding_length": params.get("char_padding_length"),
"char_emb_dim": params.get("char_emb_dim"),
})
return params
def _init_core(self):
self.char_proj_dim = self.params.get("char_proj_dim") or self.params.emb_dim
self.params.update({"char_proj_dim": self.char_proj_dim})
self.char_emb_layer = EmbeddingLayer(
self.params._char_num_tokens,
self.params.char_emb_dim,
self.params._char_padding_idx,
self.params.pop("_char_embedding_weights", None),
self.params.update_embeddings
)
self.char_conv_layer = CnnLayer(
self.params.char_emb_dim,
self.params.char_window_sizes,
self.params.char_number_of_windows
)
self.char_dropout = nn.Dropout(
p=1 - self.params.char_cnn_output_keep_prob
)
char_out_dim = sum(self.params.char_number_of_windows)
self.char_cnn_output_transform = nn.Linear(char_out_dim, self.params.char_proj_dim)
self.emb_layer = EmbeddingLayer(
self.params._num_tokens,
self.params.emb_dim,
self.params._padding_idx,
self.params.pop("_embedding_weights", None),
self.params.update_embeddings,
1 - self.params.embedder_output_keep_prob
)
self.lstm_layer = LstmLayer(
self.params.emb_dim + self.params.char_proj_dim,
self.params.lstm_hidden_dim,
self.params.lstm_num_layers,
1 - self.params.lstm_keep_prob,
self.params.lstm_bidirectional
)
self.out_dim = (
self.params.lstm_hidden_dim * 2 if self.params.lstm_bidirectional
else self.params.lstm_hidden_dim
)
def _forward_core(self, batch_data):
char_seq_ids = batch_data["char_seq_ids"] # List of [BS, SEQ_LEN]
encs = [self.char_emb_layer(_seq_ids) for _seq_ids in char_seq_ids]
encs = [self.char_conv_layer(enc) for enc in encs]
encs = [self.char_dropout(enc) for enc in encs]
encs = pad_sequence(encs, batch_first=True) # [BS, SEQ_LEN, sum(self.number_of_windows)]
char_encs = self.char_cnn_output_transform(encs) # [BS, SEQ_LEN, self.char_proj_dim]
summed_split_lengths = [
sum(_split_lengths) +
(self.encoder.number_of_terminal_tokens if self.params.add_terminals else 0)
for _split_lengths in batch_data["split_lengths"]
]
summed_split_lengths = torch.as_tensor(summed_split_lengths, dtype=torch.long) # [BS]
seq_ids = batch_data["seq_ids"] # [BS, SEQ_LEN]
word_encs = self.emb_layer(seq_ids) # [BS, SEQ_LEN, self.emb_dim]
char_plus_word_encs = torch.cat((char_encs, word_encs), dim=-1) # [BS, SEQ_LEN, sum(both)]
token_embs = self.lstm_layer(
char_plus_word_encs, summed_split_lengths) # [BS, SEQ_LEN, self.out_dim]
batch_data.update({"token_embs": token_embs})
return batch_data
class BertForTokenClassification(BaseTokenClassification):
def fit(self, examples, labels, **params):
# overriding base class' method to set params, and then calling base class' .fit()
embedder_type = params.get("embedder_type", EmbedderType.BERT.value)
if EmbedderType(embedder_type) != EmbedderType.BERT:
msg = f"{self.name} can only be used with 'embedder_type': " \
f"'{EmbedderType.BERT.value}'. " \
f"Other values passed through config params are not allowed."
raise ValueError(msg)
safe_values = {
"num_warmup_steps": 50,
"learning_rate": 2e-5,
"optimizer": "AdamW",
"max_grad_norm": 1.0
}
for k, v in safe_values.items():
v_inputted = params.get(k, v)
if v != v_inputted:
msg = f"{self.name} can be best used with '{k}' equal to '{v}' but found " \
f"the value '{v_inputted}'. Use the non-default value with caution as it " \
f"may lead to unexpected results and longer training times depending on " \
f"the choice of pretrained model."
logger.warning(msg)
else:
params.update({k: v})
params.update({
"embedder_type": embedder_type,
"save_frozen_embedder": params.get("save_frozen_embedder", False) # if True,
# frozen set of bert weights are also dumped, else they are skipped as they are not
# tuned and anyway frozen during training.
})
super().fit(examples, labels, **params)
def _create_optimizer(self):
params = list(self.named_parameters())
no_decay = ["bias", 'LayerNorm.bias', "LayerNorm.weight",
'layer_norm.bias', 'layer_norm.weight']
optimizer_grouped_parameters = [
{'params': [p for n, p in params if not any(nd in n for nd in no_decay)],
'weight_decay': 0.01},
{'params': [p for n, p in params if any(nd in n for nd in no_decay)],
'weight_decay': 0.0}
]
optimizer = getattr(torch.optim, self.params.optimizer)(
optimizer_grouped_parameters,
lr=self.params.learning_rate,
eps=1e-08,
weight_decay=0.01
)
return optimizer
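# note: weight decay is applied per parameter group above (0.01 for most weights, 0.0 for
# biases and layer-norm parameters); the group-level values take precedence over the
# optimizer-level weight_decay argument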
def _create_optimizer_and_scheduler(self, num_training_steps):
num_warmup_steps = min(int(0.1 * num_training_steps), self.params.num_warmup_steps)
self.params.update({"num_warmup_steps": num_warmup_steps})
# https://github.com/huggingface/transformers/blob/master/src/transformers/optimization.py
# refer `get_linear_schedule_with_warmup` method
def lr_lambda(current_step: int):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
return max(
0.0, float(num_training_steps - current_step) / float(
max(1, num_training_steps - num_warmup_steps))
)
# load a torch optimizer
optimizer = self._create_optimizer()
# load a lr scheduler
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
return optimizer, scheduler
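# worked example (hypothetical numbers): with num_warmup_steps == 50 and
# num_training_steps == 500, lr_lambda(25) == 25 / 50 == 0.5 (linear warmup) and
# lr_lambda(275) == (500 - 275) / (500 - 50) == 0.5 (linear decay back towards 0.0)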
def _get_dumpable_state_dict(self):
if not self.params.update_embeddings and not self.params.save_frozen_embedder:
state_dict = OrderedDict(
{k: v for k, v in self.state_dict().items() if not k.startswith("bert_model")}
)
return state_dict
return self.state_dict()
def _init_core(self):
self.bert_model = HuggingfaceTransformersContainer(
self.params.pretrained_model_name_or_path,
cache_lookup=False
).get_transformer_model()
if not self.params.update_embeddings:
for param in self.bert_model.parameters():
param.requires_grad = False
self.dropout = nn.Dropout(
p=1 - self.params.embedder_output_keep_prob
)
self.out_dim = self.params.emb_dim
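# note: self.params.emb_dim is assumed to match the hidden size of the chosen pretrained
# transformer (e.g. 768 for bert-base-*); the classifier head created in _init_graph uses
# this value as its input dimension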
def _forward_core(self, batch_data):
# refer to https://huggingface.co/docs/transformers/master/en/main_classes/output
# for more details on huggingface's bert outputs
bert_outputs = self.bert_model(**batch_data["hgf_encodings"], return_dict=True)
# 'last_hidden_state' refers to the tensor output of the final transformer layer (i.e.
# before the logit layer), hence its shape is [BS, SEQ_LEN, EMB_DIM]
last_hidden_state = bert_outputs.get("last_hidden_state") # [BS, SEQ_LEN, EMB_DIM]
if last_hidden_state is None:
msg = f"The choice of pretrained bert model " \
f"({self.params.pretrained_model_name_or_path}) " \
f"has no key 'last_hidden_state' in its output dictionary."
raise ValueError(msg)
last_hidden_state = self.dropout(last_hidden_state)
batch_data.update({"token_embs": last_hidden_state})
return batch_data
def get_token_classifier_cls(classifier_type: str, embedder_type: str = None):
try:
classifier_type = TokenClassificationType(classifier_type)
except ValueError as e:
msg = f"Neural Nets' token classification module expects classifier_type to be amongst" \
f" {[v.value for v in TokenClassificationType.__members__.values()]}" \
f" but found '{classifier_type}'."
raise ValueError(msg) from e
try:
embedder_type = EmbedderType(embedder_type)
except ValueError as e:
msg = f"Neural Nets' token classification module expects embedder_type to be amongst" \
f" {[v.value for v in EmbedderType.__members__.values()]} " \
f" but found '{embedder_type}'."
raise ValueError(msg) from e
if (
embedder_type == EmbedderType.BERT and
classifier_type not in [TokenClassificationType.EMBEDDER]
):
msg = f"To use the embedder_type '{EmbedderType.BERT.value}', " \
f"classifier_type must be '{TokenClassificationType.EMBEDDER.value}'."
raise ValueError(msg)
# disambiguation between glove, bert and non-pretrained embedders
def _resolve_and_return_embedder_class(_embedder_type):
return {
EmbedderType.NONE: EmbedderForTokenClassification,
EmbedderType.GLOVE: EmbedderForTokenClassification,
EmbedderType.BERT: BertForTokenClassification
}[_embedder_type]
return {
TokenClassificationType.EMBEDDER: _resolve_and_return_embedder_class(embedder_type),
TokenClassificationType.LSTM: LstmForTokenClassification,
TokenClassificationType.CNN_LSTM: CharCnnWithWordLstmForTokenClassification,
TokenClassificationType.LSTM_LSTM: CharLstmWithWordLstmForTokenClassification,
}[classifier_type]
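# example usage (hypothetical, using the enum members directly to avoid hard-coding their
# string values):
#   cls = get_token_classifier_cls(
#       TokenClassificationType.LSTM.value, embedder_type=EmbedderType.GLOVE.value
#   )
#   assert cls is LstmForTokenClassification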