Orthogonalisation

The singular value decomposition (SVD) factors any matrix into two orthogonal matrices and a diagonal matrix. Orthogonal matrices have almost every property one would want from a 'nice' matrix. For instance, they preserve norms and their inverse is simply their transpose, which makes them easy to work with. Diagonal matrices are similarly convenient, since they act as a simple per-coordinate scaling.

This idea can be extended to full (acyclic) tensor networks, transforming each part into a diagonal or orthogonal tensor. Tensor networks that admit this structure are often called 'canonical'.

Why orthogonalise?

Within circuit analysis, a common operation is to trace a path between two points and back. For instance, one can trace between two tokens through an attention head to measure how much these tokens attend to each other. This generally involves multiplying a subnetwork by its (slightly modified) transpose. In an orthogonalised network, the orthogonal parts cancel out in this product, greatly simplifying these computations.
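To make the cancellation concrete, here is a minimal NumPy sketch. It assumes a subnetwork that is already written as \(W = AQ\), where \(Q\) has orthonormal rows and \(A\) is a small non-orthogonal matrix; the shapes are arbitrary and chosen only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed form of a subnetwork: W = A @ Q, where Q (4 x 12) has orthonormal
# rows and A (4 x 4) is the small non-orthogonal part.
Q, _ = np.linalg.qr(rng.standard_normal((12, 4)))  # 12 x 4, orthonormal columns
Q = Q.T                                             # 4 x 12, orthonormal rows
A = rng.standard_normal((4, 4))
W = A @ Q

# Tracing "there and back": multiply the subnetwork by its transpose.
full = W @ W.T       # contracts over all 12 columns of W
reduced = A @ A.T    # Q @ Q.T == I, so the wide part cancels out

print(np.allclose(full, reduced))  # True
```

Only the small matrix \(A\) is needed for the round trip; the wide factor never has to be touched.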

Furthermore, computing norms within such networks is trivial, since a norm also relies on multiplying by the transpose. This is quite handy for interpretability, where it is often useful to know how much of the variance is explained by some subnetwork.
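A similar sketch for norms, assuming a subnetwork of the form \(W = AQ\) with orthonormal rows in \(Q\). Because those rows are orthonormal, the squared Frobenius norm splits additively over them, so the share contributed by any subset of directions can be read off the small matrix \(A\). Treating this share as 'variance explained' is an illustrative choice here, not necessarily the exact measure the text has in mind.

```python
import numpy as np

rng = np.random.default_rng(1)

# Assumed form of a subnetwork: W = A @ Q, with Q (5 x 50) having orthonormal rows.
Q = np.linalg.qr(rng.standard_normal((50, 5)))[0].T
A = rng.standard_normal((3, 5))
W = A @ Q

# ||W||_F^2 = trace(A Q Q^T A^T) = trace(A A^T) = ||A||_F^2
norm_full = np.linalg.norm(W)   # touches all 3 * 50 entries
norm_core = np.linalg.norm(A)   # only touches the 3 * 5 core

# Share of the squared norm carried by the first two orthonormal directions,
# computed once on the full matrix and once on the core alone.
part = A[:, :2] @ Q[:2]
frac_direct = np.linalg.norm(part) ** 2 / norm_full ** 2
frac_core = np.sum(A[:, :2] ** 2) / norm_core ** 2
```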

Algorithm overview

The orthogonalisation algorithm is recursive in nature, consisting of two steps.

  1. Decompose a given tensor into orthogonal and non-orthogonal parts.
  2. Contract the non-orthogonal part into the neighbouring tensor.

Rinse and repeat, starting from the inputs toward the output.
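As a simplified illustration, the two steps can be sketched for a plain chain of linear layers rather than bilinear ones. This sketch assumes SVD as the decomposition (other orthogonal factorisations could serve for step 1), and `orthogonalise_chain` is a hypothetical helper name.

```python
import numpy as np

def orthogonalise_chain(weights):
    """Hypothetical helper: sweep a chain of linear maps W_1, ..., W_L
    (each of shape out x in) from input to output. Returns factors with
    orthonormal rows plus the non-orthogonal matrix left at the output."""
    carry = None  # non-orthogonal part pushed up from the previous layer
    factors = []
    for W in weights:
        if carry is not None:
            W = W @ carry  # step 2: contract the pushed-up part into this layer
        U, S, Vt = np.linalg.svd(W, full_matrices=False)
        factors.append(Vt)  # step 1: orthogonal part (orthonormal rows)
        carry = U * S       # non-orthogonal part, equal to U @ diag(S)
    return factors, carry

rng = np.random.default_rng(2)
weights = [rng.standard_normal((8, 16)),
           rng.standard_normal((6, 8)),
           rng.standard_normal((4, 6))]
factors, core = orthogonalise_chain(weights)

original = weights[2] @ weights[1] @ weights[0]
ortho = factors[2] @ factors[1] @ factors[0]
print(np.allclose(core @ ortho, original))                  # True
print(np.allclose(ortho @ ortho.T, np.eye(ortho.shape[0])))  # True
```

The two checks confirm that the rewritten chain computes the same map and that the product of the orthogonal factors still has orthonormal rows.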

The algorithm

This section conveys the intuition behind the algorithm using a naive approach; the 'tricks' that make it efficient are explained afterward. Furthermore, we focus on stacked bilinear layers (deep MLPs). Adaptations to other architectures are explained later.

A bilinear layer (left) can be flattened into a matrix (middle) by merging its two input dimensions into one. In diagrams, this is usually drawn as combining two wires into a single thick wire.
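The flattening step can be sketched in NumPy as follows. It assumes the common bilinear parametrisation \(y = (Lx) \odot (Rx)\) with hypothetical weight matrices `L` and `R`; other parametrisations flatten the same way once written as a third-order tensor.

```python
import numpy as np

rng = np.random.default_rng(3)
d_in, d_out = 5, 3

# Assumption: a bilinear layer of the form y = (L x) * (R x), elementwise.
L = rng.standard_normal((d_out, d_in))
R = rng.standard_normal((d_out, d_in))

# As a third-order tensor: B[o, i, j] = L[o, i] * R[o, j]
B = np.einsum("oi,oj->oij", L, R)

# Merge the two input legs into one "thick wire" of size d_in * d_in.
B_flat = B.reshape(d_out, d_in * d_in)

x = rng.standard_normal(d_in)
y_layer = (L @ x) * (R @ x)
y_tensor = np.einsum("oij,i,j->o", B, x, x)
y_flat = B_flat @ np.kron(x, x)  # kron(x, x)[i * d_in + j] == x[i] * x[j]
```

The row-major `reshape` and `np.kron` use the same index ordering, which is why the flattened matrix acts on \(x \otimes x\) directly.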

Since \(B\) is now a wide matrix, we can perform an SVD (right). The result is a wide orthogonal matrix \(V\), whose rows are orthonormal so that \(VV^\top = I\), and a small non-orthogonal part \(US\).
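A sketch of the SVD step on a flattened layer. A random tensor stands in for trained weights (an assumption), and NumPy's `np.linalg.svd` returns the wide factor directly as its third output, which is named `V` here to match the text.

```python
import numpy as np

rng = np.random.default_rng(4)
d_in, d_out = 6, 4

# A random bilinear tensor stands in for a trained layer.
B = rng.standard_normal((d_out, d_in, d_in))
B_flat = B.reshape(d_out, d_in * d_in)  # wide: 4 x 36

# Thin SVD: B_flat = U @ diag(S) @ V, with V the wide factor.
U, S, V = np.linalg.svd(B_flat, full_matrices=False)
US = U * S  # small non-orthogonal part, 4 x 4

print(V.shape, US.shape)  # (4, 36) (4, 4)
```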

If we multiply the bilinear layer by its transpose (left) and apply the same SVD (middle), the wide \(V\) tensor vanishes (right), because \(BB^\top = USVV^\top SU^\top = US(US)^\top\). Hence, replacing the bilinear layer with \(V\) and pushing the \(US\) matrix up into the next layer completes the local orthogonalisation.
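Both the cancellation and the 'push up' can be checked numerically with a short sketch, again using a random flattened layer as a stand-in for real weights.

```python
import numpy as np

rng = np.random.default_rng(5)
d_in, d_out = 6, 4
B_flat = rng.standard_normal((d_out, d_in * d_in))
U, S, V = np.linalg.svd(B_flat, full_matrices=False)
US = U * S

# "There and back": B B^T = US V V^T (US)^T = US (US)^T, because V V^T = I.
gram_full = B_flat @ B_flat.T  # contracts over all d_in**2 = 36 columns
gram_small = US @ US.T         # 4 x 4 only

# Replacing the layer by V and keeping US aside leaves the function unchanged.
x = rng.standard_normal(d_in)
z = V @ np.kron(x, x)  # output of the orthogonalised layer
y = US @ z             # US is the part that gets pushed up to the next layer
```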

More interestingly, this works hierarchically. Once the first layer is orthogonalised and its \(US\) has been pushed into the second layer, performing the same operation on the second layer orthogonalises the two layers as a whole.
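A two-layer sketch of this hierarchy, assuming random bilinear tensors and a hypothetical `split` helper. Pushing \(US_1\) up means contracting it into both input legs of the second layer. The two orthogonal factors, viewed as a single matrix acting on \(x \otimes x \otimes x \otimes x\), should then have orthonormal rows, while the network still computes the same function.

```python
import numpy as np

def split(B):
    """Hypothetical helper: flatten a bilinear tensor (out, in, in)
    and split it as US @ V via a thin SVD."""
    U, S, V = np.linalg.svd(B.reshape(B.shape[0], -1), full_matrices=False)
    return U * S, V

rng = np.random.default_rng(6)
d0, d1, d2 = 3, 4, 2
B1 = rng.standard_normal((d1, d0, d0))
B2 = rng.standard_normal((d2, d1, d1))

# Layer 1: keep the orthogonal V1, push US1 into both input legs of layer 2.
US1, V1 = split(B1)
B2_pushed = np.einsum("oij,ia,jb->oab", B2, US1, US1)

# Layer 2: the same operation on the updated tensor.
US2, V2 = split(B2_pushed)

# Both orthogonal layers as one matrix on x (x) x (x) x (x) x.
M = V2 @ np.kron(V1, V1)
print(np.allclose(M @ M.T, np.eye(M.shape[0])))  # True

# The function is unchanged: compare against the original two-layer network.
x = rng.standard_normal(d0)
h = np.einsum("oij,i,j->o", B1, x, x)
y_original = np.einsum("oij,i,j->o", B2, h, h)
z = V1 @ np.kron(x, x)
y_ortho = US2 @ (V2 @ np.kron(z, z))
```

The orthogonality follows from \((V_1 \otimes V_1)(V_1 \otimes V_1)^\top = V_1V_1^\top \otimes V_1V_1^\top = I\).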

Doing this on the whole network yields a stack of orthogonal tensors and a non-orthogonal matrix at the network output.
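Putting the pieces together, a naive sweep over a whole stack might look like the sketch below. The function names are hypothetical, the layers are assumed to have no biases, and none of the efficiency tricks discussed later are used. The final check compares the original forward pass with the orthogonalised one.

```python
import numpy as np

def orthogonalise_bilinear_stack(layers):
    """Hypothetical helper: sweep bilinear tensors (out, in, in) from input
    to output. Returns the orthogonal factors V_k (rows orthonormal) and the
    non-orthogonal matrix left at the network output."""
    factors, carry = [], None
    for B in layers:
        if carry is not None:
            # Contract the pushed-up part into both input legs of this layer.
            B = np.einsum("oij,ia,jb->oab", B, carry, carry)
        U, S, V = np.linalg.svd(B.reshape(B.shape[0], -1), full_matrices=False)
        factors.append(V)
        carry = U * S
    return factors, carry

def forward(layers, x):
    # Original network: y = B(x, x) at every layer.
    for B in layers:
        x = np.einsum("oij,i,j->o", B, x, x)
    return x

def forward_ortho(factors, core, x):
    # Orthogonalised network: stacked V's, then the output matrix.
    for V in factors:
        x = V @ np.kron(x, x)
    return core @ x

rng = np.random.default_rng(7)
dims = [4, 6, 5, 3]
layers = [rng.standard_normal((dims[k + 1], dims[k], dims[k])) for k in range(3)]
factors, core = orthogonalise_bilinear_stack(layers)

x = rng.standard_normal(dims[0])
print(np.allclose(forward(layers, x), forward_ortho(factors, core, x)))  # True
```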

Direct orthogonalisation