Permutation Equivariance
Actually, checking out the 5-minute video of our work on YouTube would be a good choice! The gist of our work is all there. This post is mostly about my personal view of the research process, plus more mathematical details.
Inspiration
We all know that the Transformer is somehow invariant to token permutation, as in the famous work Intriguing Properties of Vision Transformers, shown in the figure below:
They found that ViT is quite robust to patch shuffling. We all know some of the reasons, like:
- Tokens are symmetric to self-attention.
- Transformer can only sense position information through positional encoding.
etc., etc. But all of these explanations are intuitive rather than rigorous. My advisor, Liyao Xiang, asked me,
“Why? Why is Transformer invariant to token permutation? There must be a mathematical explanation.”
And that’s the beginning of our research, and my transformation.
Invariance? Equivariance!
At the beginning of this project, we were also quite curious about the permutation properties, especially from a mathematical perspective. We called it “invariance” at the time. However, once we looked into it a little deeper, we soon discovered that the real property is not invariance, but equivariance.
$$ f(Px) = Pf(x) \tag{equivariance} $$
$$ f(Px) = f(x) \tag{invariance} $$
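To make the difference concrete, here is a tiny PyTorch sketch of my own (not from the paper): mean pooling over tokens is permutation invariant, while an element-wise ReLU is permutation equivariant.

```python
import torch

torch.manual_seed(0)
x = torch.randn(5, 3)                # 5 tokens, 3 features
P = torch.eye(5)[torch.randperm(5)]  # random permutation matrix

# Invariance: f(Px) = f(x). Mean pooling over tokens ignores token order.
pool = lambda t: t.mean(dim=0)
print(torch.allclose(pool(P @ x), pool(x)))                   # True

# Equivariance: f(Px) = P f(x). ReLU acts element-wise, so it commutes
# with any reordering of the rows.
print(torch.allclose(torch.relu(P @ x), P @ torch.relu(x)))   # True
```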
I found that the Transformer rigorously satisfies Eq. (equivariance). I was definitely not the first one to find that; Eq. (equivariance) is a kind of not-so-common common knowledge. So the Transformer is literally equivariant to token permutation, like this:
Wait, how does that even work? The Transformer is by no means linear. How can it satisfy Eq. (equivariance)? Well, let’s dive into the Transformer architecture. The main parts of the Transformer are:
- Self-Attention
- linear projection
- softmax
- Feed-Forward
- linear projection
- element-wise activation
- Shortcut
- matrix addition
Except for the linear projections, all the other parts are element-wise operations (softmax is not, strictly speaking… well, just keep reading; it behaves the same way under permutation). Element-wise operations are permutation equivariant:
$$ (P_1 A P_2) \odot ( P_1 B P_2) = P_1 (A \odot B) P_2 $$
Why? Because they are element-wise! On the left-hand side of the equation, $a_{ij}$ in $A$ and $b_{ij}$ in $B$ are permuted to the same position before the operation is applied. On the right-hand side, the operation is applied to $a_{ij}$ and $b_{ij}$ first, and the result is then permuted to that same position. So the results are the same.
How about the linear projection? Well, it is linear!
$$ P_1 (AW)= (P_1 A)W $$
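If you would rather let PyTorch do the arithmetic, here is a quick sanity check of the two identities above (my own throwaway sketch, random matrices of made-up sizes):

```python
import torch

torch.manual_seed(0)
A, B = torch.randn(4, 6), torch.randn(4, 6)
W = torch.randn(6, 6)
P1 = torch.eye(4)[torch.randperm(4)]  # row permutation
P2 = torch.eye(6)[torch.randperm(6)]  # column permutation

# Element-wise: (P1 A P2) ⊙ (P1 B P2) = P1 (A ⊙ B) P2
lhs = (P1 @ A @ P2) * (P1 @ B @ P2)
rhs = P1 @ (A * B) @ P2
print(torch.allclose(lhs, rhs))                      # True

# Linear projection: P1 (A W) = (P1 A) W
print(torch.allclose(P1 @ (A @ W), (P1 @ A) @ W))    # True
```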
Before we proceed, take a little quiz: are you familiar with the $P$ in Eq. (equivariance)? What does it mean? It is basic algebra, so basic that many people have forgotten it. If you are not so confident, here is a quick algebra recap:
Recap
To permute a matrix $X$ is to multiply it by a permutation matrix $P$. Like:
$$ X = \begin{bmatrix} 1& 2 & 3 & 4\\ 5& 6 & 7 & 8\\ 9& 10 & 11 &12 \end{bmatrix}, P = \begin{bmatrix} 0& 1 & 0 \\ 0& 0 & 1 \\ 1& 0 & 0 \end{bmatrix} $$
The order of $P$ here is $2,3,1$:
$$PX = \begin{bmatrix} 5& 6 & 7 & 8\\ 9& 10 & 11 &12\\ 1& 2 & 3 & 4 \end{bmatrix}$$
You can calculate it by yourself if you don’t believe me.
By the way, left multiplication is the row permutation, and right multiplication is the column permutation.
By the way, the permutation matrix is an orthogonal matrix, which means $P^\top = P^{-1}$.
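Or let PyTorch calculate it for you (again, a throwaway sketch of mine):

```python
import torch

X = torch.arange(1., 13.).reshape(3, 4)
P = torch.tensor([[0., 1., 0.],
                  [0., 0., 1.],
                  [1., 0., 0.]])

print(P @ X)                                   # rows in the order 2, 3, 1, as above
print(torch.allclose(P.T @ P, torch.eye(3)))   # True: P is orthogonal, P^T = P^{-1}

# Right multiplication permutes columns instead (it needs a 4x4 P here).
P4 = torch.eye(4)[[1, 2, 3, 0]]
print(X @ P4)                                  # columns reordered
```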
Here comes the truly interesting part. After I reported my findings to my dear advisor, she asked me,
“What about the backward propagation? What happens in the backward?”
Backward Propagation
So I went back, taught myself some basic matrix calculus, and tried to calculate the total differential of the Transformer. That’s crazy. I mean, I still can’t believe the total differential of a huge neural network can be written in one line. But it can. And it is quite simple.
At first, it did not make sense: the learning process is neither invariant nor equivariant at all.
Formulation of Transformer
To explain this, we need to formulate the Transformer and its training process. Normally in math we denote a vector as a column vector, but to make it less confusing when we implement things in PyTorch, here we denote a token as a row vector. The Transformer looks like this:
OK, so the input of the Transformer is a sequence of tokens, denoted as $Z$; in ViT-Base, for example, $Z\in\mathbb{R}^{197\times 768}$. When $Z$ enters the Transformer, it meets $W_Q, W_K, W_V$ first. Let’s say:
$$\begin{align} Q = ZW_Q^\top\end{align}$$ $$\begin{align} K = ZW_K^\top\end{align}$$ $$\begin{align} V = ZW_V^\top\end{align}$$
Then, the self-attention is calculated as:
$$\begin{align} S = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) \end{align}$$ $$\begin{align} A = SV \end{align}$$
We neglect the residual connection and the attention projection for simplicity. The MLP feed-forward is calculated as:
$$\begin{align} A_1 = A W_1^\top \end{align}$$ $$\begin{align} H = \text{ReLU}(A_1) \end{align}$$ $$\begin{align} A_2 = H W_2^\top \end{align}$$
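To keep things concrete, here is a minimal PyTorch sketch of exactly this simplified layer (single head, no residual, no attention projection, tiny made-up sizes), together with a check of the forward equivariance from Eq. (equivariance):

```python
import torch

torch.manual_seed(0)
n, d, d_ff = 8, 16, 32                 # tiny hypothetical sizes, just for a quick check
Z = torch.randn(n, d)                  # tokens as row vectors

W_Q, W_K, W_V = [torch.randn(d, d) / d**0.5 for _ in range(3)]
W_1 = torch.randn(d_ff, d) / d**0.5
W_2 = torch.randn(d, d_ff) / d_ff**0.5

def layer(Z):
    Q, K, V = Z @ W_Q.T, Z @ W_K.T, Z @ W_V.T              # linear projections
    S = torch.softmax(Q @ K.T / K.shape[-1]**0.5, dim=-1)  # row-wise softmax
    A = S @ V                                              # self-attention output
    H = torch.relu(A @ W_1.T)                              # feed-forward, hidden
    return H @ W_2.T                                       # feed-forward output A_2

# Forward equivariance: layer(P Z) = P layer(Z)
P = torch.eye(n)[torch.randperm(n)]
print(torch.allclose(layer(P @ Z), P @ layer(Z), atol=1e-5))  # True
```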
Total Differential
After layers of those, we get the output of the Transformer, which we denote as $A_3$. Actually we only need to consider one layer, and induction will do the rest, so let’s say $A_3 = A_2$. And… let’s just make the downstream head linear: $O = A_3 W^\top$. So the gradient is:
$$ \begin{aligned} \text{d}l &= \text{tr}\left(\frac{\partial l}{\partial O}^\top \text{d}O\right)\\ &= \text{tr}\left(\frac{\partial l}{\partial O}^\top( \text{d}A_3)W^\top \right) + \text{tr}\left(\frac{\partial l}{\partial O}^\top A_3 \text{d}W^\top\right)\\ \end{aligned} $$
where $l$ is the loss function.
Allow me to direct you to the textbook Matrix Calculus: Derivation and Simple Application by HU Pili, if you are unfamiliar with the derivation above. I assure you that after reading it, you will find the derivation for the Transformer quite simple.
That indicates: $$\begin{align} \frac{\partial l}{\partial A_3} = \frac{\partial l}{\partial O}W \end{align}$$ and: $$ \begin{align} \frac{\partial l}{\partial W} = \frac{\partial l}{\partial O}^\top A_3 \end{align}$$ So in the forward propagation, we know the value of $A_3$ and $W$. Once we know $\frac{\partial l}{\partial O}$, we can calculate the gradient of $W$ and $A_3$. As for the value of $\frac{\partial l}{\partial O}$, well, forget about it. PyTorch knows it and that’s enough. To us, it is just a known value.
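If you do not feel like doing the trace gymnastics yourself, a tiny autograd check (my own sketch, with made-up shapes and a made-up loss) confirms both identities:

```python
import torch

torch.manual_seed(0)
A3 = torch.randn(8, 16, requires_grad=True)   # "output of the Transformer"
W = torch.randn(4, 16, requires_grad=True)    # linear downstream head
O = A3 @ W.T
l = O.square().sum()                          # any scalar loss will do
l.backward()

dl_dO = 2 * O                                 # ∂l/∂O for this particular loss
print(torch.allclose(A3.grad, dl_dO @ W))     # ∂l/∂A_3 = (∂l/∂O) W      -> True
print(torch.allclose(W.grad, dl_dO.T @ A3))   # ∂l/∂W   = (∂l/∂O)^T A_3  -> True
```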
Permute It! Emmm….
Now you see, if we permute the input $Z$ to $PZ$, the output becomes $A_{3(P)} = P A_2$. And the total differential becomes:
$$ \begin{align} \text{d}l_{(P)} &= \text{tr}\left(\frac{\partial l_{(P)}}{\partial O_{(P)}}^\top \text{d}(PA_2)W^\top \right) + \text{tr}\left(\frac{\partial l_{(P)}}{\partial O_{(P)}}^\top PA_2 \text{d}W^\top\right)\\ \end{align} $$ where we denote all the variables in the permuted setting with a subscript $(P)$. For example, $$\begin{align} A_{2(P)} = P A_2\end{align}$$
This…just makes no sense.
For nights and nights I stared at the equations. Then it suddenly hit me: if we permute $PA_2$ back, things would be different:
$$ \begin{align}A_{3(P)} = P^\top A_{2(P)} = P^\top \cdot P A_2 = A_2 = A_3 \end{align} $$
Now, since $A_3 = A_{3(P)}$, everything downstream, the loss and the gradients, would be the same. $$ \begin{align}l_{(P)} = l, O_{(P)} = O \end{align}$$ $$ \begin{align}\frac{\partial l_{(P)}}{\partial O_{(P)}} = \frac{\partial l}{\partial O} \end{align}$$ $$ \begin{align}\frac{\partial l_{(P)}}{\partial A_{3(P)}} = \frac{\partial l}{\partial A_3} \end{align}$$
OK, pause here. Before we proceed, I sincerely recommend that you take a look at the relationship between $A_3, A_2, A_{3(P)}, A_{2(P)}$. What we are trying to do here is find a bridge between the original setting and the permuted setting, especially for the gradients of the weights $W_Q, W_K, W_V, W_1, W_2$. To get there, we must pass the entrance: the relationship between $\frac{\partial l_{(P)}}{\partial A_{2(P)}}$ and $\frac{\partial l}{\partial A_2}$.
Now let’s focus on the relationship between $\frac{\partial l_{(P)}}{\partial A_{2(P)}}$ and $\frac{\partial l}{\partial A_2}$. Since $A_{3(P)} = P^\top A_{2(P)}$,
$$ \begin{align} \text{d}l_{(P)} &= \text{tr}\left(\frac{\partial l_{(P)}}{\partial A_{3(P)}}^\top P^\top \text{d}A_{2(P)}\right) \\ &= \text{tr}\left((P\frac{\partial l_{(P)}}{\partial A_{3(P)}})^\top \text{d}A_{2(P)}\right) \\ \end{align} $$ and thus: $$ \begin{align}\frac{\partial l_{(P)}}{\partial A_{2(P)}} = P\frac{\partial l_{(P)}}{\partial A_{3(P)}} = P\frac{\partial l}{\partial A_3} = P\frac{\partial l}{\partial A_2} \end{align}$$
You see what’s happening here? The gradient of $A_2$ in the permuted setting equals the gradient of $A_2$ in the original setting, permuted!
Wait, we want to know the gradient of weights, right? Eq. (10) tells us that: $$ \frac{\partial l}{\partial W_2} = \frac{\partial l}{\partial A_2}^\top H $$ so $$\begin{align} \frac{\partial l_{(P)}}{\partial W_{2(P)}} &= \frac{\partial l_{(P)}}{\partial A_{2(P)}}^\top H_{(P)}\\ &= \frac{\partial l}{\partial A_2}^\top P^\top PH\\ &= \frac{\partial l}{\partial A_2}^\top H\\ &= \frac{\partial l}{\partial W_2} \end{align}$$
OK, we can pop the champagne now.
Or maybe later; otherwise we would be drunk before we finished the paper. There are a lot of champagne moments later. For me, there were even more. After I found this, I immediately turned on my computer and fine-tuned a ViT model on CIFAR-10, permuted of course. And it worked! All the properties above were verified. Ah, nothing is more satisfying than seeing the theory work in practice.
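A toy version of that sanity check (my own minimal sketch with made-up sizes, not the actual CIFAR-10 experiment) might look like this: permute the input, permute the output back before the head, and compare the weight gradients with the unpermuted run.

```python
import torch

torch.manual_seed(0)
n, d, d_ff, c = 8, 16, 32, 4
Z = torch.randn(n, d)
P = torch.eye(n)[torch.randperm(n)]

# One shared set of initial weights (fan-in scaled), reused by both runs.
init = [torch.randn(s) / s[1]**0.5
        for s in [(d, d), (d, d), (d, d), (d_ff, d), (d, d_ff), (c, d)]]

def weight_grads(Z_in, unpermute=None):
    W_Q, W_K, W_V, W_1, W_2, W = ws = [w.clone().requires_grad_(True) for w in init]
    Q, K, V = Z_in @ W_Q.T, Z_in @ W_K.T, Z_in @ W_V.T
    S = torch.softmax(Q @ K.T / d**0.5, dim=-1)
    A_2 = torch.relu((S @ V) @ W_1.T) @ W_2.T
    A_3 = unpermute @ A_2 if unpermute is not None else A_2  # permute back before the head
    (A_3 @ W.T).square().sum().backward()                    # linear head + a simple loss
    return [w.grad for w in ws]

grads   = weight_grads(Z)                      # original setting
grads_P = weight_grads(P @ Z, unpermute=P.T)   # permuted input, un-permuted before the head

print(all(torch.allclose(g, gp, atol=1e-5) for g, gp in zip(grads, grads_P)))  # True
```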
PZ… ZP?
Even though I was quite excited about the result, and the scientifically satisfying moment was surely a transformative experience, this was not enough for a paper.
Then one night I was lying in bed, with Q and K and V and P all floating in my mind. Suddenly, a floating P met a floating Z, not like “PZ”, but like “ZP”.
Holy, how could I have never thought of this? Left multiplication is the row permutation, and right multiplication is the column permutation.
Nah, that makes no sense: $$ ZPW \neq ZW$$
Later, another P floated by: $$ ZP \cdot P^\top W = ZW$$
Wait, if I add one more P…
$$ ZP \cdot P^\top W_1 P = ZW_1P, \quad (ZW_1)P \cdot P^\top W_2 P = (ZW_1)W_2P$$
The permutation passes through the linear layer and on and on!
Wait wait wait, what about the backward? Emmm… $$ \begin{align} \text{d}l&=\text{tr}(\frac{\partial l}{\partial A_{3(P)}}^{\top} P_R^{\top} \text{d}( A_{2(P)}) P_C^{\top})\\ &=\text{tr}( P_C^{\top}\frac{\partial l}{\partial A_{3(P)}}^{\top} P_R^{\top} \text{d} A_{2(P)})\\ &=\text{tr}(( P_R \frac{\partial l}{\partial A_{3(P)}} P_C)^{\top}\text{d} A_{2(P)}), \end{align} $$ that is: $$ \begin{equation} \frac{\partial l}{\partial A_{2(P)}}= P_R\frac{\partial l}{\partial A_{3(P)}} P_C = P_R\frac{\partial l}{\partial A_{2}} P_C \end{equation} $$ where $P_R$ and $P_C$ are the row-wise and column-wise permutation matrices.
Wait, we are close to the gradients of the weights:
$$ \begin{align} \frac{\partial l}{\partial W_{2(P)}}&=\frac{\partial l}{\partial A_{2(P)}}^{\top} H_{(P)}\\ &= P_C^{\top}\frac{\partial l}{\partial A_{2}}^{\top} P_R^{\top}\cdot P_R H P_C \\ &= P_C^{\top}\frac{\partial l}{\partial A_{2}}^{\top} H P_C \\ &= P_C^{\top} \frac{\partial l}{\partial W_{2}} P_C , \end{align} $$
Holy, that’s it!
Well… I made some simplifications here. For instance, if you look into it a little deeper, you will find that $H$ and $W_2$ are not square matrices, so the multiplication does not work exactly like that. If you are interested in the details, like a serious proof, I would recommend reading the appendix of our paper.
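Still, for the simplified square case above (pretend $H$ and $W_2$ are square), a quick sketch of mine shows the claim numerically: permute the input rows and columns, conjugate the weights by $P_C$ as in the forward-pass trick, un-permute the output before the loss, and the weight gradients come out conjugated by $P_C$ as well.

```python
import torch

torch.manual_seed(0)
n, d = 8, 16
Z = torch.randn(n, d)
P_R = torch.eye(n)[torch.randperm(n)]   # row (token) permutation
P_C = torch.eye(d)[torch.randperm(d)]   # column (feature) permutation
W_1, W_2 = torch.randn(d, d) / d**0.5, torch.randn(d, d) / d**0.5

def grads(Z_in, V_1, V_2, unpermute=None):
    V_1, V_2 = V_1.clone().requires_grad_(True), V_2.clone().requires_grad_(True)
    A_2 = torch.relu(Z_in @ V_1) @ V_2          # a tiny square "layer", as in the text
    A_3 = unpermute(A_2) if unpermute else A_2  # permute rows and columns back
    A_3.square().sum().backward()               # any scalar loss will do
    return V_1.grad, V_2.grad

# Original setting.
g1, g2 = grads(Z, W_1, W_2)

# Permuted setting: input P_R Z P_C, weights conjugated to P_C^T W P_C,
# output un-permuted (rows by P_R^T, columns by P_C^T) before the loss.
g1p, g2p = grads(P_R @ Z @ P_C, P_C.T @ W_1 @ P_C, P_C.T @ W_2 @ P_C,
                 unpermute=lambda A: P_R.T @ A @ P_C.T)

# The permuted-setting gradients are the original ones, conjugated by P_C.
print(torch.allclose(g1p, P_C.T @ g1 @ P_C, atol=1e-5))  # True
print(torch.allclose(g2p, P_C.T @ g2 @ P_C, atol=1e-5))  # True
```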
Is this true in practice?
There was a dramatic story when I was verifying all the properties with experiments. Anyway, this can perfectly explain the practical results.
Say you have a randomly initialized model and you train it with permutation-unpermutation. According to the theory above, the resulting weights would look like $P^\top W P$. So if you manually permute all the weights back, the model would function the same as the original model.
But does the equation really hold? $$\begin{align} W_{(P)} = P^\top W P \end{align}$$ It surely can explain the results, but is it true? If we permute the initial weights, and somehow control the seed and the damn random CUDA cuSOLVER, does the equation hold?
The math says yes, but who knows? And most importantly, who cares? I have another explanation for permutation equivariance, which I will not share here because it is boring, and it also holds. I think, no, the math says: it depends on the initial weights. If the initial weights are permuted like $W_{0(P)} = P^\top W_0 P$, then the math says the equation holds. If the initialization caters to another explanation, then the math turns to that one. But no matter what, the theory above always guides practice.
Is this really useful?
Indeed the properties above are quite interesting, but many would ask: what for? Well, I don’t know.
Maybe we can guide the design of permutation-invariant models. For example, according to the math above, the invariance of the Set Transformer only holds for the forward pass, not for the backward; the order of the set still has an impact on training.
Maybe we can use it for encryption, as mentioned in the paper.
Or maybe, we can just tell more people about this paper and see if they can find some interesting ideas.