| """ | |
| This file provides an implementation of several different neural network modules that are used for merging and | |
| transforming input data in various ways. The following components can be used when we are dealing with | |
| data from multiple modes, or when we need to merge multiple intermediate embedded representations in | |
| the forward process of a model. | |
| The main classes defined in this code are: | |
| - BilinearGeneral: This class implements a bilinear transformation layer that applies a bilinear transformation to | |
| incoming data, as described in the "Multiplicative Interactions and Where to Find Them", published at ICLR 2020, | |
| https://openreview.net/forum?id=rylnK6VtDH. The transformation involves two input features and an output | |
| feature, and also includes an optional bias term. | |
| - TorchBilinearCustomized: This class implements a bilinear layer similar to the one provided by PyTorch | |
| (torch.nn.Bilinear), but with additional customizations. This class can be used as an alternative to the | |
| BilinearGeneral class. | |
| - TorchBilinear: This class is a simple wrapper around the PyTorch's built-in nn.Bilinear module. It provides the | |
| same functionality as PyTorch's nn.Bilinear but within the structure of the current module. | |
| - FiLM: This class implements a Feature-wise Linear Modulation (FiLM) layer. FiLM layers apply an affine | |
| transformation to the input data, conditioned on some additional context information. | |
| - GatingType: This is an enumeration class that defines different types of gating mechanisms that can be used in | |
| the modules. | |
| - SumMerge: This class provides a simple summing mechanism to merge input streams. | |
| - VectorMerge: This class implements a more complex merging mechanism for vector streams. | |
| The streams are first transformed using layer normalization, a ReLU activation, and a linear layer. | |
| Then they are merged either by simple summing or by using a gating mechanism. | |
| The implementation of these classes involves PyTorch and Numpy libraries, and the classes use PyTorch's nn.Module as | |
| the base class, making them compatible with PyTorch's neural network modules and functionalities. | |
| These modules can be useful building blocks in more complex deep learning architectures. | |
| """ | |
import enum
import math
from collections import OrderedDict
from typing import List, Dict, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class BilinearGeneral(nn.Module):
    """
    Overview:
        Bilinear implementation as in: Multiplicative Interactions and Where to Find Them,
        ICLR 2020, https://openreview.net/forum?id=rylnK6VtDH.
    Interfaces:
        ``__init__``, ``forward``
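    Examples:
        A minimal usage sketch; the batch size and feature sizes below are arbitrary, illustrative choices:

        >>> layer = BilinearGeneral(in1_features=4, in2_features=6, out_features=8)
        >>> x = torch.randn(2, 4)
        >>> z = torch.randn(2, 6)
        >>> layer(x, z).shape
        torch.Size([2, 8])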
| """ | |

    def __init__(self, in1_features: int, in2_features: int, out_features: int):
        """
        Overview:
            Initialize the Bilinear layer.
        Arguments:
            - in1_features (:obj:`int`): The size of each first input sample.
            - in2_features (:obj:`int`): The size of each second input sample.
            - out_features (:obj:`int`): The size of each output sample.
        """
        super(BilinearGeneral, self).__init__()
        # Initialize the weight tensors W, U, V and the bias vector b
        self.W = nn.Parameter(torch.Tensor(out_features, in1_features, in2_features))
        self.U = nn.Parameter(torch.Tensor(out_features, in2_features))
        self.V = nn.Parameter(torch.Tensor(out_features, in1_features))
        self.b = nn.Parameter(torch.Tensor(out_features))
        self.in1_features = in1_features
        self.in2_features = in2_features
        self.out_features = out_features
        self.reset_parameters()

    def reset_parameters(self):
        """
        Overview:
            Initialize the parameters of the Bilinear layer.
        """
        stdv = 1. / np.sqrt(self.in1_features)
        self.W.data.uniform_(-stdv, stdv)
        self.U.data.uniform_(-stdv, stdv)
        self.V.data.uniform_(-stdv, stdv)
        self.b.data.uniform_(-stdv, stdv)

    def forward(self, x: torch.Tensor, z: torch.Tensor):
        """
        Overview:
            Compute the bilinear function.
        Arguments:
            - x (:obj:`torch.Tensor`): The first input tensor.
            - z (:obj:`torch.Tensor`): The second input tensor.
        """
        # Compute the bilinear function x^T W z + U z + V x + b
        # x^T W z
        out_W = torch.einsum('bi,kij,bj->bk', x, self.W, z)
        # U z
        out_U = z.matmul(self.U.t())
        # V x
        out_V = x.matmul(self.V.t())
        # x^T W z + U z + V x + b
        out = out_W + out_U + out_V + self.b
        return out


class TorchBilinearCustomized(nn.Module):
    """
    Overview:
        Customized Torch Bilinear implementation.
    Interfaces:
        ``__init__``, ``forward``
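    Examples:
        A minimal usage sketch; the shapes below are arbitrary, illustrative choices:

        >>> layer = TorchBilinearCustomized(in1_features=4, in2_features=6, out_features=8)
        >>> x = torch.randn(2, 4)
        >>> z = torch.randn(2, 6)
        >>> layer(x, z).shape
        torch.Size([2, 8])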
| """ | |

    def __init__(self, in1_features: int, in2_features: int, out_features: int):
        """
        Overview:
            Initialize the Bilinear layer.
        Arguments:
            - in1_features (:obj:`int`): The size of each first input sample.
            - in2_features (:obj:`int`): The size of each second input sample.
            - out_features (:obj:`int`): The size of each output sample.
        """
        super(TorchBilinearCustomized, self).__init__()
        self.in1_features = in1_features
        self.in2_features = in2_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(out_features, in1_features, in2_features))
        self.bias = nn.Parameter(torch.Tensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """
        Overview:
            Initialize the parameters of the Bilinear layer.
        """
        bound = 1 / math.sqrt(self.in1_features)
        nn.init.uniform_(self.weight, -bound, bound)
        nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x: torch.Tensor, z: torch.Tensor):
        """
        Overview:
            Compute the bilinear function.
        Arguments:
            - x (:obj:`torch.Tensor`): The first input tensor.
            - z (:obj:`torch.Tensor`): The second input tensor.
        """
        # Bilinear operation via torch.einsum: out[b, o] = sum_{i, j} x[b, i] * weight[o, i, j] * z[b, j]
        out = torch.einsum('bi,oij,bj->bo', x, self.weight, z) + self.bias
        # The squeeze only has an effect when out_features == 1
        return out.squeeze(-1)
| """ | |
| Overview: | |
| Implementation of the Bilinear layer as in PyTorch: | |
| https://pytorch.org/docs/stable/generated/torch.nn.Bilinear.html#torch.nn.Bilinear | |
| Arguments: | |
| - in1_features (:obj:`int`): The size of each first input sample. | |
| - in2_features (:obj:`int`): The size of each second input sample. | |
| - out_features (:obj:`int`): The size of each output sample. | |
| - bias (:obj:`bool`): If set to False, the layer will not learn an additive bias. Default: ``True``. | |
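Examples:
    A minimal usage sketch; since TorchBilinear is simply an alias of nn.Bilinear, the shapes below are
    arbitrary, illustrative choices:

    >>> layer = TorchBilinear(in1_features=4, in2_features=6, out_features=8)
    >>> x = torch.randn(2, 4)
    >>> z = torch.randn(2, 6)
    >>> layer(x, z).shape
    torch.Size([2, 8])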
| """ | |
| TorchBilinear = nn.Bilinear | |


class FiLM(nn.Module):
    """
    Overview:
        Feature-wise Linear Modulation (FiLM) Layer.
        This layer applies a feature-wise affine transformation to the input, conditioned on context.
    Interfaces:
        ``__init__``, ``forward``
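    Examples:
        A minimal usage sketch; feature_dim=16 and context_dim=8 are arbitrary, illustrative choices:

        >>> film = FiLM(feature_dim=16, context_dim=8)
        >>> feature = torch.randn(4, 16)
        >>> context = torch.randn(4, 8)
        >>> film(feature, context).shape
        torch.Size([4, 16])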
| """ | |

    def __init__(self, feature_dim: int, context_dim: int):
        """
        Overview:
            Initialize the FiLM layer.
        Arguments:
            - feature_dim (:obj:`int`): The dimension of the input feature vector.
            - context_dim (:obj:`int`): The dimension of the input context vector.
        """
        super(FiLM, self).__init__()
        # Define the fully connected layer for the context.
        # The output dimension is twice the feature dimension, to produce both gamma and beta.
        self.context_layer = nn.Linear(context_dim, 2 * feature_dim)

    def forward(self, feature: torch.Tensor, context: torch.Tensor):
        """
        Overview:
            Forward propagation.
        Arguments:
            - feature (:obj:`torch.Tensor`): The input feature, shape (batch_size, feature_dim).
            - context (:obj:`torch.Tensor`): The input context, shape (batch_size, context_dim).
        Returns:
            - conditioned_feature (:obj:`torch.Tensor`): The output feature after FiLM, shape (batch_size, feature_dim).
        """
        # Pass the context through the fully connected layer
        out = self.context_layer(context)
        # Split the output into two halves along dim 1 (the feature dimension): gamma and beta
        gamma, beta = torch.split(out, out.shape[1] // 2, dim=1)
        # Apply the feature-wise affine transformation
        conditioned_feature = gamma * feature + beta
        return conditioned_feature


class GatingType(enum.Enum):
    """
    Overview:
        Enum class defining different types of tensor gating and aggregation in the merge modules.
    """
    NONE = 'none'
    GLOBAL = 'global'
    POINTWISE = 'pointwise'


class SumMerge(nn.Module):
    """
    Overview:
        A PyTorch module that merges a list of tensors by computing their sum. All input tensors must have the
        same size. This module can work with any type of tensor (vector, units or visual).
    Interfaces:
        ``__init__``, ``forward``
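    Examples:
        A minimal usage sketch; the tensors only need to share the same shape, which is arbitrary here:

        >>> merge = SumMerge()
        >>> tensors = [torch.ones(2, 3), torch.ones(2, 3), torch.ones(2, 3)]
        >>> merge(tensors).shape
        torch.Size([2, 3])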
| """ | |

    def forward(self, tensors: List[Tensor]) -> Tensor:
        """
        Overview:
            Forward pass of the SumMerge module, which sums the input tensors.
        Arguments:
            - tensors (:obj:`List[Tensor]`): List of input tensors to be summed. All tensors must have the same size.
        Returns:
            - summed (:obj:`Tensor`): Tensor resulting from the sum of all input tensors.
        """
        # Stack the tensors along a new first dimension
        stacked = torch.stack(tensors, dim=0)
        # Compute the sum along the first dimension, which is equivalent to ``sum(tensors)``
        summed = torch.sum(stacked, dim=0)
        return summed


class VectorMerge(nn.Module):
    """
    Overview:
        Merges multiple vector streams. Streams are first transformed through layer normalization, ReLU and linear
        layers, then summed. They do not need to have the same size. Gating can also be used before the sum.
    Interfaces:
        ``__init__``, ``encode``, ``_compute_gate``, ``forward``

    .. note::
        For more details about the gating types, please refer to the GatingType enum class.
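
    Examples:
        A minimal usage sketch; the stream names, sizes and gating type below are arbitrary, illustrative choices:

        >>> merge = VectorMerge({'scalar_info': 8, 'entity_info': 16}, output_size=32,
        ...                     gating_type=GatingType.POINTWISE)
        >>> inputs = {'scalar_info': torch.randn(4, 8), 'entity_info': torch.randn(4, 16)}
        >>> merge(inputs).shape
        torch.Size([4, 32])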
| """ | |

    def __init__(
        self,
        input_sizes: Dict[str, int],
        output_size: int,
        gating_type: GatingType = GatingType.NONE,
        use_layer_norm: bool = True,
    ):
        """
        Overview:
            Initialize the `VectorMerge` module.
        Arguments:
            - input_sizes (:obj:`Dict[str, int]`): A dictionary mapping input names to their sizes. \
                The size is a single integer for 1D inputs; a non-positive size (e.g. 0) marks a scalar \
                (0D per batch element) input, which is internally treated as size 1.
            - output_size (:obj:`int`): The size of the output vector.
            - gating_type (:obj:`GatingType`): The type of gating mechanism to use. Default is `GatingType.NONE`.
            - use_layer_norm (:obj:`bool`): Whether to use layer normalization. Default is `True`.
        """
        super().__init__()
        self._input_sizes = OrderedDict(input_sizes)
        self._output_size = output_size
        self._gating_type = gating_type
        self._use_layer_norm = use_layer_norm
        if self._use_layer_norm:
            self._layer_norms = nn.ModuleDict()
        else:
            self._layer_norms = None
        self._linears = nn.ModuleDict()
        for name, size in self._input_sizes.items():
            linear_input_size = size if size > 0 else 1
            if self._use_layer_norm:
                self._layer_norms[name] = nn.LayerNorm(linear_input_size)
            self._linears[name] = nn.Linear(linear_input_size, self._output_size)
        self._gating_linears = nn.ModuleDict()
        if self._gating_type is GatingType.GLOBAL:
            self.gate_size = 1
        elif self._gating_type is GatingType.POINTWISE:
            self.gate_size = self._output_size
        elif self._gating_type is GatingType.NONE:
            self._gating_linears = None
        else:
            raise ValueError(f'Gating type {self._gating_type} is not supported')
        if self._gating_linears is not None:
            if len(self._input_sizes) == 2:
                # more efficient than the general version below
                for name, size in self._input_sizes.items():
                    gate_input_size = size if size > 0 else 1
                    gating_layer = nn.Linear(gate_input_size, self.gate_size)
                    torch.nn.init.normal_(gating_layer.weight, std=0.005)
                    torch.nn.init.constant_(gating_layer.bias, 0.0)
                    self._gating_linears[name] = gating_layer
            else:
                for name, size in self._input_sizes.items():
                    gate_input_size = size if size > 0 else 1
                    gating_layer = nn.Linear(gate_input_size, len(self._input_sizes) * self.gate_size)
                    torch.nn.init.normal_(gating_layer.weight, std=0.005)
                    torch.nn.init.constant_(gating_layer.bias, 0.0)
                    self._gating_linears[name] = gating_layer

    def encode(self, inputs: Dict[str, Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
        """
        Overview:
            Encode the input tensors using layer normalization, ReLU and linear transformations.
        Arguments:
            - inputs (:obj:`Dict[str, Tensor]`): The input tensors.
        Returns:
            - gates (:obj:`List[Tensor]`): The gate tensors after transformations.
            - outputs (:obj:`List[Tensor]`): The output tensors after transformations.
        """
        gates, outputs = [], []
        for name, size in self._input_sizes.items():
            feature = inputs[name]
            if size <= 0 and feature.dim() == 1:
                feature = feature.unsqueeze(-1)
            feature = feature.to(torch.float32)
            if self._use_layer_norm and name in self._layer_norms:
                feature = self._layer_norms[name](feature)
            feature = F.relu(feature)
            gates.append(feature)
            outputs.append(self._linears[name](feature))
        return gates, outputs

    def _compute_gate(
        self,
        init_gate: List[Tensor],
    ) -> List[Tensor]:
        """
        Overview:
            Compute the gate values based on the initial gate values.
        Arguments:
            - init_gate (:obj:`List[Tensor]`): The initial gate values.
        Returns:
            - gate (:obj:`List[Tensor]`): The computed gate values.
        """
        if len(self._input_sizes) == 2:
            gate = [self._gating_linears[name](y) for name, y in zip(self._input_sizes.keys(), init_gate)]
            gate = sum(gate)
            sigmoid = torch.sigmoid(gate)
            gate = [sigmoid, 1.0 - sigmoid]
        else:
            gate = [self._gating_linears[name](y) for name, y in zip(self._input_sizes.keys(), init_gate)]
            gate = sum(gate)
            gate = gate.reshape([-1, len(self._input_sizes), self.gate_size])
            gate = F.softmax(gate, dim=1)
            assert gate.shape[1] == len(self._input_sizes)
            gate = [gate[:, i] for i in range(len(self._input_sizes))]
        return gate

    def forward(self, inputs: Dict[str, Tensor]) -> Tensor:
        """
        Overview:
            Forward pass through the VectorMerge module.
        Arguments:
            - inputs (:obj:`Dict[str, Tensor]`): The input tensors.
        Returns:
            - output (:obj:`Tensor`): The output tensor after passing through the module.
        """
        gates, outputs = self.encode(inputs)
        if len(outputs) == 1:
            # Special case of a single input stream, which does not need any gating.
            output = outputs[0]
        elif self._gating_type is GatingType.NONE:
            output = sum(outputs)
        else:
            gate = self._compute_gate(gates)
            data = [g * d for g, d in zip(gate, outputs)]
            output = sum(data)
        return output