Source code for mindmeld.models.nn_utils.layers

# -*- coding: utf-8 -*-
#
# Copyright (c) 2015 Cisco Systems, Inc. and others.  All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import List, Union, Dict

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence

from .._util import _get_module_or_attr

nn_module = _get_module_or_attr("torch.nn", "Module")

logger = logging.getLogger(__name__)


class EmbeddingLayer(nn_module):
    """A pytorch wrapper layer for embeddings that takes as input a batched sequence of ids and
    outputs the embeddings corresponding to those ids
    """

    def __init__(
        self,
        num_tokens: int,
        emb_dim: int,
        padding_idx: int = None,
        embedding_weights: Dict[int, Union[List, np.ndarray]] = None,
        update_embeddings: bool = True,
        embeddings_dropout: float = 0.5,
        coefficients: List[float] = None,
        update_coefficients: bool = True
    ):
        """
        Args:
            num_tokens (int): size of the dictionary of embeddings
            emb_dim (int): the size of each embedding vector
            padding_idx (int, Optional): If given, pads the output with the embedding vector at
                `padding_idx` (initialized to zeros) whenever it encounters the index.
            embedding_weights (Dict[int, Union[List, np.ndarray]], Optional): weights to overwrite
                the already initialized embedding weights
            update_embeddings (bool, Optional): whether to freeze or train the embedding weights
            embeddings_dropout (float, Optional): dropout rate to apply on the forward call
            coefficients (List[float], Optional): weight coefficients for the dictionary of
                embeddings
            update_coefficients (bool, Optional): whether to freeze or train the coefficients
        """
        super().__init__()

        self.embeddings = nn.Embedding(num_tokens, emb_dim, padding_idx=padding_idx)
        if embedding_weights is not None:
            if isinstance(embedding_weights, dict):
                # when weights are passed as dict with keys as indices and values as embeddings
                for idx, emb in embedding_weights.items():
                    self.embeddings.weight.data[idx] = torch.as_tensor(emb)
                msg = f"Initialized {len(embedding_weights)} embedding weights " \
                      f"from the embedder model"
                logger.info(msg)
            else:
                # when weights are passed as an array or tensor
                self.embeddings.load_state_dict({'weight': torch.as_tensor(embedding_weights)})
        self.embeddings.weight.requires_grad = update_embeddings

        self.embedding_for_coefficients = None
        if coefficients is not None:
            if len(coefficients) != num_tokens:
                msg = f"Length of coefficients ({len(coefficients)}) must match the number of " \
                      f"embeddings ({num_tokens})"
                raise ValueError(msg)
            self.embedding_for_coefficients = nn.Embedding(num_tokens, 1, padding_idx=padding_idx)
            self.embedding_for_coefficients.load_state_dict(
                {'weight': torch.as_tensor(coefficients).view(-1, 1)}
            )
            self.embedding_for_coefficients.weight.requires_grad = update_coefficients

        self.dropout = nn.Dropout(embeddings_dropout)

    def forward(self, padded_token_ids: "Tensor2d[int]") -> "Tensor3d[float]":
        # padded_token_ids: dim: [BS, SEQ_LEN]
        # returns:          dim: [BS, SEQ_LEN, EMB_DIM]

        # [BS, SEQ_LEN] -> [BS, SEQ_LEN, EMB_DIM]
        outputs = self.embeddings(padded_token_ids)

        if self.embedding_for_coefficients is not None:
            # [BS, SEQ_LEN] -> [BS, SEQ_LEN, 1]
            coefficients = self.embedding_for_coefficients(padded_token_ids)
            # [BS, SEQ_LEN, EMB_DIM] -> [BS, SEQ_LEN, EMB_DIM]
            outputs = torch.mul(outputs, coefficients)

        outputs = self.dropout(outputs)
        return outputs
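
# Illustrative usage sketch for EmbeddingLayer (not part of the library); the vocabulary size,
# embedding dimension, and batch shape below are arbitrary assumptions for demonstration:
#
#   layer = EmbeddingLayer(num_tokens=100, emb_dim=16, padding_idx=0)
#   padded_token_ids = torch.randint(1, 100, (4, 10))   # [BS=4, SEQ_LEN=10]
#   embs = layer(padded_token_ids)                       # -> [4, 10, 16]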


class CnnLayer(nn_module):
    """A pytorch wrapper layer for 2D Convolutions
    """

    def __init__(self, emb_dim: int, kernel_sizes: List[int], num_kernels: List[int]):
        """
        Args:
            emb_dim (int): the size of embedding vectors or last dimension of hidden state prior
                to this CNN layer (i.e. width for convolution filters)
            kernel_sizes (List[int]): the length of each kernel provided as a list of lengths
                (i.e. length for convolution filters)
            num_kernels (List[int]): the number of kernels for each kernel size provided as a
                list of numbers (one number per provided size of kernel)
        """
        super().__init__()

        if isinstance(num_kernels, list) and len(num_kernels) != len(kernel_sizes):
            # incorrect length of num_kernels list specified
            num_kernels = [num_kernels[0]] * len(kernel_sizes)
        elif isinstance(num_kernels, int) and num_kernels > 0:
            # num_kernels is a single integer value
            num_kernels = [num_kernels] * len(kernel_sizes)
        elif not isinstance(num_kernels, list):
            msg = f"Invalid value for num_kernels: {num_kernels}. " \
                  f"Expected an int or a list of the same length as " \
                  f"kernel_sizes ({len(kernel_sizes)})"
            raise ValueError(msg)

        self.convs = nn.ModuleList()
        # Unsqueeze input dim [BS, SEQ_LEN, EMB_DIM] to [BS, 1, SEQ_LEN, EMB_DIM] and send as input
        # Each conv module's output dimensions are [BS, n, SEQ_LEN, 1]
        for kernel_size, num_kernel in zip(kernel_sizes, num_kernels):
            self.convs.append(
                nn.Sequential(
                    nn.Conv2d(1, num_kernel, (kernel_size, emb_dim),
                              padding=(kernel_size - 1, 0), dilation=1,
                              bias=True, padding_mode='zeros'),
                    nn.ReLU(),
                )
            )

    def forward(self, padded_token_embs: "Tensor3d[float]") -> "Tensor2d[float]":
        # padded_token_embs: dim: [BS, SEQ_LEN, EMB_DIM]
        # returns:           dim: [BS, sum(num_kernels)]

        # [BS, SEQ_LEN, EMB_DIM] -> [BS, 1, SEQ_LEN, EMB_DIM]
        embs_unsqueezed = torch.unsqueeze(padded_token_embs, dim=1)
        # [BS, 1, SEQ_LEN, EMB_DIM] -> list([BS, n, SEQ_LEN])
        conv_outputs = [conv(embs_unsqueezed).squeeze(3) for conv in self.convs]
        # list([BS, n, SEQ_LEN]) -> list([BS, n])
        maxpool_conv_outputs = [F.max_pool1d(out, out.size(2)).squeeze(2) for out in conv_outputs]
        # list([BS, n]) -> [BS, sum(n)]
        outputs = torch.cat(maxpool_conv_outputs, dim=1)
        return outputs
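
# Illustrative usage sketch for CnnLayer (not part of the library); shapes and hyperparameters
# below are arbitrary assumptions for demonstration:
#
#   layer = CnnLayer(emb_dim=16, kernel_sizes=[2, 3, 4], num_kernels=[50, 50, 50])
#   padded_token_embs = torch.rand(4, 10, 16)   # [BS=4, SEQ_LEN=10, EMB_DIM=16]
#   pooled = layer(padded_token_embs)           # -> [4, 150] i.e. [BS, sum(num_kernels)]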


class LstmLayer(nn_module):
    """A pytorch wrapper layer for BiLSTMs
    """

    def __init__(
        self,
        emb_dim: int,
        hidden_dim: int,
        num_layers: int,
        lstm_dropout: float,
        bidirectional: bool
    ):
        """
        Args:
            emb_dim (int): the size of embedding vectors or last dimension of hidden state prior
                to this LSTM layer
            hidden_dim (int): the hidden dimension for nn.LSTM
            num_layers (int): the number of nn.LSTM layers to stack
            lstm_dropout (float): the dropout rate for nn.LSTM
            bidirectional (bool): whether LSTMs should be applied on both forward and backward
                sequences of the input or not
        """
        super().__init__()

        self.lstm = nn.LSTM(
            emb_dim, hidden_dim, num_layers=num_layers, dropout=lstm_dropout,
            bidirectional=bidirectional, batch_first=True
        )

    def forward(
        self,
        padded_token_embs: "Tensor3d[float]",
        lengths: "Tensor1d[int]",
    ) -> "Tensor3d[float]":
        # padded_token_embs: dim: [BS, SEQ_LEN, EMB_DIM]
        # lengths:           dim: [BS]
        # returns:           dim: [BS, SEQ_LEN, HID_DIM*(2 if bidirectional else 1)]

        # [BS, SEQ_LEN, EMB_DIM] -> [BS, SEQ_LEN, HID_DIM*(2 if bidirectional else 1)]
        lengths = lengths.to(torch.device("cpu"))
        packed = pack_padded_sequence(padded_token_embs, lengths,
                                      batch_first=True, enforce_sorted=False)
        lstm_outputs, _ = self.lstm(packed)
        outputs = pad_packed_sequence(lstm_outputs, batch_first=True)[0]
        return outputs
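
# Illustrative usage sketch for LstmLayer (not part of the library); shapes and hyperparameters
# below are arbitrary assumptions for demonstration:
#
#   layer = LstmLayer(emb_dim=16, hidden_dim=32, num_layers=2, lstm_dropout=0.3,
#                     bidirectional=True)
#   padded_token_embs = torch.rand(4, 10, 16)      # [BS=4, SEQ_LEN=10, EMB_DIM=16]
#   lengths = torch.tensor([10, 7, 5, 3])          # number of non-pad tokens per sequence
#   outputs = layer(padded_token_embs, lengths)    # -> [4, 10, 64] i.e. hidden_dim * 2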


class PoolingLayer(nn_module):
    """A pooling layer for Tensor3d objects that pools along the sequence dimension. Assumes that
    padding, if any, exists on the right side of inputs (i.e. not in the beginning of inputs)
    """

    def __init__(self, pooling_type: str):
        """
        Args:
            pooling_type (str): the choice of pooling; to be amongst the following:
                first: the first index of each sequence will be the pooled output (similar to
                    CLS token in BERT models)
                last: the last index of each sequence will be the pooled output (useful for
                    pooling outputs from nn.LSTM)
                max: max pool across the sequence will be the pooled output
                mean: mean pool across the sequence will be the pooled output
                mean_sqrt: similar to 'mean' but divided by the square root of sequence length
        """
        super().__init__()

        pooling_type = pooling_type.lower()
        allowed_pooling_types = ["first", "last", "max", "mean", "mean_sqrt"]
        if pooling_type not in allowed_pooling_types:
            msg = f"Expected pooling_type amongst {allowed_pooling_types} " \
                  f"but found '{pooling_type}'"
            raise ValueError(msg)

        # assumption: first token is never a pad token for the passed inputs
        self._requires_length = ["last", "max", "mean", "mean_sqrt"]
        self.pooling_type = pooling_type

    def forward(
        self,
        padded_token_embs: "Tensor3d[float]",
        lengths: "Tensor1d[int]" = None,
    ) -> "Tensor2d[float]":
        # padded_token_embs: dim: [BS, SEQ_LEN, EMB_DIM]
        # lengths:           dim: [BS]
        # returns:           dim: [BS, EMB_DIM]

        if self.pooling_type in self._requires_length and lengths is None:
            msg = f"Missing required value 'lengths' for pooling_type: {self.pooling_type}"
            raise ValueError(msg)

        if self.pooling_type == "first":
            outputs = padded_token_embs[:, 0, :]
        elif self.pooling_type == "last":
            last_seq_idxs = torch.LongTensor([x - 1 for x in lengths])
            outputs = padded_token_embs[range(padded_token_embs.shape[0]), last_seq_idxs, :]
        else:
            try:
                target_device = padded_token_embs.device
                mask = pad_sequence(
                    [torch.as_tensor([1] * length_) for length_ in lengths],
                    batch_first=True,
                    padding_value=0.0,
                ).unsqueeze(-1).expand(padded_token_embs.size()).float().to(target_device)
            except RuntimeError as e:
                msg = f"Unable to create a mask for '{self.pooling_type}' pooling operation in " \
                      f"{self.__class__.__name__}. It is possible that your choice of tokenizer " \
                      f"does not split input text at whitespace (e.g. roberta-base tokenizer), " \
                      f"due to which the tokenization of a word differs with and without " \
                      f"context. If working with a transformers model, consider changing the " \
                      f"pretrained model name and restarting training."
                raise ValueError(msg) from e
            if self.pooling_type == "max":
                padded_token_embs[mask == 0] = -1e9  # set to a large negative value
                outputs, _ = torch.max(padded_token_embs, dim=1)
            elif self.pooling_type == "mean":
                summed_padded_token_embs = torch.sum(padded_token_embs * mask, dim=1)
                outputs = summed_padded_token_embs / mask.sum(1)
            elif self.pooling_type == "mean_sqrt":
                summed_padded_token_embs = torch.sum(padded_token_embs * mask, dim=1)
                expanded_lengths = lengths.unsqueeze(dim=1).expand(summed_padded_token_embs.size())
                outputs = torch.div(summed_padded_token_embs, torch.sqrt(expanded_lengths.float()))
        return outputs
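
# Illustrative usage sketch for PoolingLayer (not part of the library); shapes below are
# arbitrary assumptions for demonstration:
#
#   layer = PoolingLayer(pooling_type="mean")
#   padded_token_embs = torch.rand(4, 10, 16)      # [BS=4, SEQ_LEN=10, EMB_DIM=16]
#   lengths = torch.tensor([10, 7, 5, 3])          # number of non-pad tokens per sequence
#   pooled = layer(padded_token_embs, lengths)     # -> [4, 16]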


class SplittingAndPoolingLayer(nn_module):
    """Pooling class that first splits a sequence of representations into subgroups of
    representations based on the inputted lengths of subgroups, and pools each subgroup
    separately.
    """

    def __init__(self, pooling_type: str, number_of_terminal_tokens: int):
        """
        Args:
            pooling_type (str): the choice of pooling; to be amongst the following:
                first: the first index of each subsequence will be the pooled output of that
                    subgroup (e.g. token classification using BERT models with sub-word
                    tokenization)
                last: the last index of each subsequence will be the pooled output of that
                    subgroup (e.g. for word level representations when using a character BiLSTM)
                max: max pool across the subsequence will be the pooled output of that subgroup
                mean: mean pool across the subsequence will be the pooled output of that subgroup
                    (e.g. token classification using BERT models with sub-word tokenization)
                mean_sqrt: similar to 'mean' but divided by the square root of subsequence length
            number_of_terminal_tokens (int): the number of terminal tokens that will be discarded
                if discard_terminals is set to True in the forward method.
        """
        super().__init__()

        self.pooling_type = pooling_type.lower()
        self.number_of_terminal_tokens = number_of_terminal_tokens
        self.pooling_layer = PoolingLayer(pooling_type=self.pooling_type)

    def _split_and_pool(
        self,
        tensor_2d: "Tensor2d[float]",
        list_of_subgroup_lengths: "Tensor1d[int]",
        discard_terminals: bool
    ):
        # tensor_2d:                dim: [SEQ_LEN, EMB_DIM]
        # list_of_subgroup_lengths: dim: list of ints summing up to SEQ_LEN' <= SEQ_LEN
        # discard_terminals:        bool
        # returns:                  dim: [SEQ_LEN'', EMB_DIM]

        if discard_terminals:
            # TODO: Number of terminals can also be 1 (maybe just left or just right) in some
            #  models
            if self.number_of_terminal_tokens != 2:
                msg = f"Unable to combine sub-tokens' representations for each word into one in " \
                      f"{self.__class__.__name__}. It is possible that your choice of tokenizer " \
                      f"has {self.number_of_terminal_tokens} terminal token(s) instead of the " \
                      f"assumed 2 terminals."  # (e.g. t5-base tokenizer)
                raise NotImplementedError(msg)
            # list_of_subgroup_lengths consists of lengths of only non-terminal subgroups but
            # the inputted tensor_2d includes the terminals
            seq_len_required = sum(list_of_subgroup_lengths) + self.number_of_terminal_tokens
            tensor_2d_with_terminals = tensor_2d[:seq_len_required]
            tensor_2d = tensor_2d_with_terminals[1:-1]  # discard terminal representations
        else:
            seq_len_required = sum(list_of_subgroup_lengths)
            tensor_2d = tensor_2d[:seq_len_required]

        try:
            # argument 'split_sizes' (position 1) must be a tuple of ints, not a Tensor
            splits = torch.split(tensor_2d, list_of_subgroup_lengths.tolist(), dim=0)
        except RuntimeError as e:
            msg = f"Unable to combine sub-tokens' representations for each word into one in " \
                  f"{self.__class__.__name__}. It is possible that your choice of tokenizer " \
                  f"does not split input text at whitespace (e.g. roberta-base tokenizer), due " \
                  f"to which one-representation-per-word cannot be obtained to do tagging at " \
                  f"word-level for token classification."
            raise ValueError(msg) from e

        padded_token_embs = pad_sequence(splits, batch_first=True)  # [BS', SEQ_LEN', EMB_DIM]

        # return dims: [len(list_of_subgroup_lengths), EMB_DIM]
        pooled_repr_for_each_subgroup = self.pooling_layer(
            padded_token_embs=padded_token_embs, lengths=list_of_subgroup_lengths
        )
        return pooled_repr_for_each_subgroup

    def forward(
        self,
        padded_token_embs: "Tensor3d[float]",
        span_lengths: "List[Tensor1d[int]]",
        discard_terminals: bool = None
    ):
        # padded_token_embs: dim: [BS, SEQ_LEN, EMB_DIM]
        # span_lengths:      dim: List[list of ints summing up to SEQ_LEN' <= SEQ_LEN]
        # discard_terminals: bool
        # returns:           dim: [BS, SEQ_LEN', EMB_DIM]

        outputs = pad_sequence([
            self._split_and_pool(_padded_token_embs, _span_lengths, discard_terminals)
            for _padded_token_embs, _span_lengths in zip(padded_token_embs, span_lengths)
        ], batch_first=True)
        return outputs
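
# Illustrative usage sketch for SplittingAndPoolingLayer (not part of the library); shapes and
# sub-word groupings below are arbitrary assumptions for demonstration. With BERT-style inputs
# of the form [CLS] tok_1 ... tok_n [SEP], the two terminals can be dropped before pooling
# sub-word representations into one vector per word:
#
#   layer = SplittingAndPoolingLayer(pooling_type="mean", number_of_terminal_tokens=2)
#   padded_token_embs = torch.rand(2, 8, 16)       # [BS=2, SEQ_LEN=8, EMB_DIM=16]
#   span_lengths = [torch.tensor([2, 3, 1]), torch.tensor([1, 4])]  # sub-words per word
#   outputs = layer(padded_token_embs, span_lengths, discard_terminals=True)
#   # -> [2, 3, 16]: one pooled representation per word, padded to the longest word count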