Bilinear

The Bilinear class represents the main building block for our networks. While all operations on this object are simple to implement in code, the mathematical reasoning behind them isn't. Hence, this section describes the class function by function.

Representation

Any bilinear form can be represented by a matrix, just like any linear form is represented by a vector. However, we want to perform multiple operations on any given input. Just as a matrix performs multiple linear operations, a third-order tensor can perform multiple bilinear operations.
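
As a minimal illustration in plain PyTorch (not this repo's API; all names here are hypothetical), a single bilinear form \(x^\top A x\) is given by one matrix \(A\), and stacking several such matrices into a third-order tensor evaluates all of them in a single contraction:

```python
import torch

torch.manual_seed(0)
d_in, d_out = 5, 3

# One bilinear form: a single d_in x d_in matrix A yields the scalar x^T A x.
A = torch.randn(d_in, d_in)
x = torch.randn(d_in)
single = x @ A @ x

# Several bilinear forms: a third-order tensor B holds one matrix per output.
B = torch.randn(d_out, d_in, d_in)
B[0] = A
outputs = torch.einsum("kij,i,j->k", B, x, x)

# Output k is the bilinear form defined by the slice B[k].
assert torch.allclose(outputs[0], single, rtol=1e-4, atol=1e-5)
print(outputs.shape)  # torch.Size([3])
```

The tensor has shape (d_out, d_in, d_in), which is where the cubic parameter count comes from.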

Representing a dense third-order tensor requires \(n^3\) parameters (weights) when inputs and outputs are \(n\)-dimensional, which is often intractable for neural networks. Hence, we represent this tensor using a decomposition.

While many tensor decompositions exist, we opt for the canonical polyadic (CP) decomposition, which can be seen as an SVD for three modes: it consists of an implicit diagonal third-order core tensor and a factor matrix for each mode.
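
In symbols (our own notation, with the hidden dimension \(h\) as the number of rank-one terms), factor matrices \(L, R \in \mathbb{R}^{h \times d_{in}}\) and \(P \in \mathbb{R}^{d_{out} \times h}\) define

\[ B_{kij} = \sum_{m=1}^{h} P_{km} L_{mi} R_{mj}, \qquad y = P\big((Lx) \odot (Rx)\big), \]

so a forward pass never materializes \(B\). The sketch below checks this identity numerically and compares parameter counts. The shapes are assumed to match those implied by the symmetrization code in this section (l, r of shape (hidden, d_in), p of shape (d_out, hidden)), and the helper `param_counts` is hypothetical.

```python
import torch

torch.manual_seed(0)
d_in, d_out, hidden = 6, 4, 10

# CP factors (assumed shapes): one matrix per mode; the diagonal core stays implicit.
l = torch.randn(hidden, d_in)
r = torch.randn(hidden, d_in)
p = torch.randn(d_out, hidden)

# Materialize the full (d_out, d_in, d_in) tensor only to verify the identity.
B = torch.einsum("km,mi,mj->kij", p, l, r)

x = torch.randn(d_in)
dense = torch.einsum("kij,i,j->k", B, x, x)
factored = p @ ((l @ x) * (r @ x))
assert torch.allclose(dense, factored, rtol=1e-4, atol=1e-4)

def param_counts(d_in, d_out, hidden):
    """Hypothetical helper: (dense tensor, CP factors) parameter counts."""
    return d_out * d_in * d_in, hidden * (2 * d_in + d_out)

print(param_counts(512, 512, 2048))  # (134217728, 3145728)
```

Even with a 4x up-projected hidden dimension, the factored form is over 40 times smaller than the dense tensor at this size.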

As in an ordinary MLP, the input is often up-projected to a larger hidden dimension before the element-wise product and down-projected again afterwards. This has been found to work well in practice, and these hidden layers have been hypothesized to be where models store their factual knowledge. Formally, though, up-projecting increases the number of rank-one terms in the decomposition, and hence the maximum tensor rank.

Initialization

This class is often not used directly as a neural network component (even though it could be). Instead, a separate class that is optimized for speed handles that use case. The main reason is that a bilinear layer needs to represent the biases (and residual) as an extra input dimension, which can be detrimental to GPU performance (for instance, the extra dimension breaks the nicely aligned tensor sizes that GPU kernels favor). Hence, the ubiquitous nn.Linear is used for training.

There are three common functions to instantiate a bilinear layer. The first, make_bilinear, takes three nn.Linear modules as arguments, one for each mode. It automatically incorporates the biases into the bilinear layer as the first dimension (since this often turns out to be the most important one).
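
The idea of incorporating a bias as the first dimension can be sketched by prepending a constant \(1\) to the input, since \(Wx + b = [\,b \;\; W\,]\,[1;\, x]\). This only illustrates the trick for the two input modes; it is not make_bilinear's actual implementation, and the helper `augment` is hypothetical.

```python
import torch
import torch.nn as nn

def augment(linear: nn.Linear) -> torch.Tensor:
    """Hypothetical helper: prepend the bias as column 0 of the weight matrix."""
    return torch.cat([linear.bias[:, None], linear.weight], dim=1)

torch.manual_seed(0)
d_in, hidden = 4, 8
left, right = nn.Linear(d_in, hidden), nn.Linear(d_in, hidden)

with torch.no_grad():
    x = torch.randn(d_in)
    x1 = torch.cat([torch.ones(1), x])  # constant 1 as the first input dimension

    # Element-wise product of two affine maps (with biases) ...
    reference = left(x) * right(x)
    # ... equals a bias-free product on the augmented input.
    folded = (augment(left) @ x1) * (augment(right) @ x1)

assert torch.allclose(reference, folded, atol=1e-5)
print(augment(left).shape)  # torch.Size([8, 5])
```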

Functions

This class provides many common operations that one may wish to perform on it. Many operations come in two versions: in-place and cloning (returning a modified copy). Following PyTorch convention, in-place operations are suffixed with an underscore.
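
The naming convention can be sketched with a toy class. It is hypothetical (not this repo's Bilinear), and `scale`/`scale_` are illustrative names. A common pattern, assumed here, is to implement the in-place version and have the cloning version call it on a copy.

```python
import torch

class ToyBilinear:
    """Hypothetical stand-in holding factors l, r (hidden x d_in) and p (d_out x hidden)."""

    def __init__(self, l, r, p):
        self.l, self.r, self.p = l, r, p

    def scale_(self, c: float) -> "ToyBilinear":
        # In-place: scaling the output matrix scales every output; returns self.
        self.p.mul_(c)
        return self

    def scale(self, c: float) -> "ToyBilinear":
        # Cloning: copy all factors, then reuse the in-place version on the copy.
        return ToyBilinear(self.l.clone(), self.r.clone(), self.p.clone()).scale_(c)

layer = ToyBilinear(torch.ones(2, 3), torch.ones(2, 3), torch.ones(4, 2))
scaled = layer.scale(2.0)
print(layer.p[0, 0].item(), scaled.p[0, 0].item())  # 1.0 2.0
layer.scale_(3.0)
print(layer.p[0, 0].item())  # 3.0
```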

Symmetrization

When tracing through a layer, or considering it as part of a decomposition, the two input matrices present two different paths. This leads to a number of paths that grows exponentially with the number of layers. Hence, it is advantageous to symmetrize the input matrices. This essentially turns the bilinear layer into a squared activation function, often at some 'cost' (such as a larger hidden dimension).

One way to symmetrize is by doubling the hidden dimension, allowing the information of either side to be combined into a single matrix.

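```python
# Body of a Bilinear method: stack (l + r) and (l - r) along the hidden dimension.
lr = torch.cat([self.l + self.r, self.l - self.r], dim=0)
# Output weights +p and -p; 0.25 cancels the factor 4 in (l+r)^2 - (l-r)^2 = 4lr.
p = 0.25 * torch.cat([self.p, -self.p], dim=1)
return Bilinear(lr, lr, p)
```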

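Working this out for a single hidden unit, we get the following, where $l$ and $r$ denote the pre-activations $l \cdot x$ and $r \cdot x$. We ignore $p$ and view its two halves as $1$ and $-1$ for simplicity.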

\[ (l + r)(l + r) - (l - r)(l - r) = (l^2 + 2lr + r^2) - (l^2 - 2lr + r^2) = 4lr\]

Basically, this setup uses a nifty trick where the purely quadratic terms, not present in the bilinear layer, cancel out. Only the cross-interacting terms, which we want, remain.
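
A quick numerical check of this trick with plain tensors, assuming the shapes implied by the snippet above (l, r of shape (hidden, d_in), p of shape (d_out, hidden)). The helper `forward` is hypothetical, and the output weights are scaled by 1/4, which cancels the factor 4 in the identity above.

```python
import torch

def forward(l, r, p, x):
    """Hypothetical bilinear forward pass: y = p @ ((l @ x) * (r @ x))."""
    return p @ ((l @ x) * (r @ x))

torch.manual_seed(0)
d_in, d_out, hidden = 5, 3, 7
l, r = torch.randn(hidden, d_in), torch.randn(hidden, d_in)
p = torch.randn(d_out, hidden)

# Symmetrize: double the hidden dimension and use the same matrix on both sides.
lr = torch.cat([l + r, l - r], dim=0)       # (2 * hidden, d_in)
p_sym = 0.25 * torch.cat([p, -p], dim=1)    # 1/4 cancels the factor 4

x = torch.randn(d_in)
original = forward(l, r, p, x)
symmetric = forward(lr, lr, p_sym, x)       # equals p_sym @ (lr @ x) ** 2
assert torch.allclose(original, symmetric, rtol=1e-4, atol=1e-4)
```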

TODO: complex symmetrization

Change of basis

At its core, a bilinear layer consists of an element-wise product (a square, once symmetrized) surrounded by two changes of basis: the input matrices and the output matrix. Commonly, we wish to fold matrices into either side (input/output) as part of our analysis. We define two functions to this end.

  • fold: contracts a matrix into both input matrices.
  • project: contracts a matrix into the output matrix.
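
A minimal sketch of what the two operations do to the factors, under the same assumed shapes. Here `fold` and `project` are free functions with hypothetical signatures, not the class's methods. Folding \(M\) into the inputs means the layer now reads \(z\) with \(x = Mz\); projecting \(M\) into the output means it returns \(My\).

```python
import torch

def forward(l, r, p, x):
    """Hypothetical bilinear forward pass: y = p @ ((l @ x) * (r @ x))."""
    return p @ ((l @ x) * (r @ x))

def fold(l, r, p, M):
    # Contract M into both input matrices: the layer now consumes z with x = M @ z.
    return l @ M, r @ M, p

def project(l, r, p, M):
    # Contract M into the output matrix: the layer now returns M @ y.
    return l, r, M @ p

torch.manual_seed(0)
d_in, d_out, hidden, d_new = 5, 3, 7, 4
l, r, p = torch.randn(hidden, d_in), torch.randn(hidden, d_in), torch.randn(d_out, hidden)

M_in = torch.randn(d_in, d_new)   # e.g. a basis for the input space
M_out = torch.randn(2, d_out)     # e.g. a readout of the output space
z, x = torch.randn(d_new), torch.randn(d_in)

folded_out = forward(*fold(l, r, p, M_in), z)
projected_out = forward(*project(l, r, p, M_out), x)
assert torch.allclose(folded_out, forward(l, r, p, M_in @ z), rtol=1e-4, atol=1e-3)
assert torch.allclose(projected_out, M_out @ forward(l, r, p, x), rtol=1e-4, atol=1e-3)
```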

Algebraic operations

Bilinear layers have built-in support for scalar multiplication as well as two kinds of addition.

The first kind of addition, denoted as the + operation, implements a 'natural' addition of bilinear layers.

\[ x' = xAx + xBx \]
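
In the CP representation, the sum of two tensors is obtained by concatenating their factors along the hidden dimension, so this addition is exact without ever forming the tensors. The sketch assumes the same factor shapes as above; `forward` and `add` are hypothetical helpers, not the class's `+` implementation.

```python
import torch

def forward(l, r, p, x):
    """Hypothetical bilinear forward pass: y = p @ ((l @ x) * (r @ x))."""
    return p @ ((l @ x) * (r @ x))

def add(a, b):
    # Hypothetical '+': concatenate the hidden dimensions of two layers that
    # share input and output spaces; their outputs then simply sum.
    (la, ra, pa), (lb, rb, pb) = a, b
    return torch.cat([la, lb]), torch.cat([ra, rb]), torch.cat([pa, pb], dim=1)

torch.manual_seed(0)
d_in, d_out = 5, 3
A = (torch.randn(7, d_in), torch.randn(7, d_in), torch.randn(d_out, 7))
B = (torch.randn(4, d_in), torch.randn(4, d_in), torch.randn(d_out, 4))

x = torch.randn(d_in)
summed = forward(*add(A, B), x)
assert torch.allclose(summed, forward(*A, x) + forward(*B, x), rtol=1e-4, atol=1e-4)
```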

The second kind, denoted as ^, is a concatenation which roughly computes the following, where \(\Vert\) denotes concatenation.

\[ (x' \Vert y') = (x \Vert y)\,(A \Vert B)\,(x \Vert y) \]

The former is commonly used when working with a single bilinear layer; the latter is used when the bilinear layer is part of a larger object.
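
One way to realize the concatenation, assuming it means that \(A\) reads \(x\) and writes \(x'\) while \(B\) reads \(y\) and writes \(y'\), is to place the factors block-diagonally. This is a sketch under that assumption, with hypothetical helpers `forward` and `concat`; it is not the class's `^` implementation.

```python
import torch

def forward(l, r, p, x):
    """Hypothetical bilinear forward pass: y = p @ ((l @ x) * (r @ x))."""
    return p @ ((l @ x) * (r @ x))

def concat(a, b):
    # Hypothetical '^': block-diagonal factors, so A only sees the x part and
    # writes the x' part, while B only sees the y part and writes the y' part.
    (la, ra, pa), (lb, rb, pb) = a, b
    return torch.block_diag(la, lb), torch.block_diag(ra, rb), torch.block_diag(pa, pb)

torch.manual_seed(0)
A = (torch.randn(7, 5), torch.randn(7, 5), torch.randn(3, 7))  # maps 5 -> 3
B = (torch.randn(4, 2), torch.randn(4, 2), torch.randn(6, 4))  # maps 2 -> 6

x, y = torch.randn(5), torch.randn(2)
out = forward(*concat(A, B), torch.cat([x, y]))
assert torch.allclose(out, torch.cat([forward(*A, x), forward(*B, y)]), rtol=1e-4, atol=1e-4)
print(out.shape)  # torch.Size([9])
```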

Singular vectors

Many analysis methods used in this repo rely on spectral decompositions. Hence, we often need to compute right- and left-singular vectors.

Computing singular vectors essentially amounts to multiplying a matrix (or bilinear layer) with its transpose: the left-singular vectors of a matrix \(A\) are the eigenvectors of \(AA^\top\), and its right-singular vectors are those of \(A^\top A\). Let's start with left-singular vectors.
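
Assuming that the left-singular vectors of a layer refer to those of its output-mode unfolding \(B_{(1)}\) (the \(d_{out} \times d_{in}^2\) matrix), the product with the transpose can be computed from the CP factors alone, since \(\sum_{ij} B_{kij} B_{k'ij} = \sum_{m,n} P_{km} (LL^\top)_{mn} (RR^\top)_{mn} P_{k'n}\):

\[ B_{(1)} B_{(1)}^\top = P\,\big[(LL^\top) \odot (RR^\top)\big]\,P^\top. \]

The sketch below checks this against an explicit SVD of the materialized tensor. The shapes follow the same assumed convention as above, and all variable names are hypothetical.

```python
import torch

torch.manual_seed(0)
d_in, d_out, hidden = 6, 4, 10
l = torch.randn(hidden, d_in, dtype=torch.float64)
r = torch.randn(hidden, d_in, dtype=torch.float64)
p = torch.randn(d_out, hidden, dtype=torch.float64)

# Gram matrix of the output-mode unfolding, computed from the factors only.
gram = p @ ((l @ l.T) * (r @ r.T)) @ p.T         # (d_out, d_out)
eigvals, eigvecs = torch.linalg.eigh(gram)       # eigenvalues in ascending order

# Reference: materialize B, unfold it to (d_out, d_in * d_in), and take its SVD.
B = torch.einsum("km,mi,mj->kij", p, l, r).reshape(d_out, -1)
U, S, _ = torch.linalg.svd(B, full_matrices=False)

# Eigenvalues of the Gram matrix are the squared singular values (descending) ...
assert torch.allclose(eigvals.flip(0), S**2, rtol=1e-6)
# ... and its eigenvectors are the left-singular vectors up to sign.
overlap = (eigvecs.flip(1).T @ U).diagonal().abs()
assert torch.allclose(overlap, torch.ones(d_out, dtype=torch.float64))
```

Working with the \(d_{out} \times d_{out}\) Gram matrix avoids ever building the \(d_{out} \times d_{in}^2\) unfolding.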