weight = attention(token_query, weight_keys, weight_values).
In other words, they query weight_keys to fetch the weight_values, and mix them to compute each weight on the spot. Increasing model size becomes a matter of adding more weight_keys and weight_values, and incrementally training them.
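A minimal NumPy sketch of that idea: input tokens attend over learnable key and value "parameter tokens". Two details are assumptions, not the paper's confirmed code: the normalisation (GELU applied to per-row L2-normalised, rescaled scores, standing in for the paper's modified softmax, with a fixed hypothetical `tau`) and zero-initialising the new keys when the layer grows. Under those assumptions, appending rows leaves the output unchanged, so training can resume from where it left off.

```python
import numpy as np


def gelu(x):
    # tanh approximation of GELU; gelu(0) == 0 exactly
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def pattention(x, key_params, value_params, tau=4.0):
    """Attend from input tokens x (batch, d_in) to parameter tokens.

    key_params: (n, d_in), value_params: (n, d_out).
    ASSUMPTION: scores are L2-normalised per row, scaled by a fixed tau,
    then passed through GELU instead of the usual exp/sum softmax.
    """
    scores = x @ key_params.T                              # (batch, n)
    norms = np.linalg.norm(scores, axis=-1, keepdims=True)
    weights = gelu(scores / norms * tau)                   # (batch, n)
    return weights @ value_params                          # (batch, d_out)


def grow(key_params, value_params, extra, rng):
    # Append `extra` parameter tokens: zero keys (assumed init), random values
    new_keys = np.zeros((extra, key_params.shape[1]))
    new_values = rng.normal(size=(extra, value_params.shape[1]))
    return (np.concatenate([key_params, new_keys]),
            np.concatenate([value_params, new_values]))


rng = np.random.default_rng(0)
x = rng.normal(size=(3, 8))
K, V = rng.normal(size=(16, 8)), rng.normal(size=(16, 8))

before = pattention(x, K, V)
K2, V2 = grow(K, V, extra=16, rng=rng)
after = pattention(x, K2, V2)
print(np.allclose(before, after))  # True: zero keys score 0, and gelu(0) = 0
```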
Simple, clever, and it seems to work well. Beautiful.
https://www.desmos.com/calculator/3rtqsyapxo
The above assumes the columns of K are normalised, but bear with me. K and V together form a vector database. V holds the payloads, each row containing a vector of data. K describes the positions of these points in space, on the surface of a hypersphere. The query vector describes the query into the database: its direction describes the point in space that's being queried, and its magnitude sets the radius of the query (a larger magnitude means a tighter query). The result is the weighted average of the vectors in V, each weighted by a smooth Gaussian falloff in its key's distance from the query direction, with the width set by the query magnitude. I recommend a recent paper from Nvidia that achieves a significant speedup by normalising vectors to a hypersphere: https://arxiv.org/abs/2410.01131v1
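The Gaussian picture can be checked numerically. For unit-norm keys k and a query q = r·u (magnitude r, unit direction u), ||u − k||² = 2 − 2 u·k, so q·k = r − (r/2)·||u − k||². The constant r cancels inside the softmax, leaving weights proportional to exp(−(r/2)·||u − k||²): a Gaussian centred on the query direction with variance 1/r. The sizes and random data below are arbitrary.

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())  # subtract max for numerical stability
    return e / e.sum()


rng = np.random.default_rng(0)
n, d = 8, 3
K = rng.normal(size=(n, d))
K /= np.linalg.norm(K, axis=1, keepdims=True)  # keys on the unit hypersphere
V = rng.normal(size=(n, 2))                    # payloads
q = rng.normal(size=d) * 4.0

r = np.linalg.norm(q)   # query magnitude (inverse width of the query)
u = q / r               # query direction (a point on the sphere)

w_attn = softmax(K @ q)                   # ordinary dot-product attention weights
dist2 = np.sum((K - u) ** 2, axis=1)      # squared distance from u to each key
w_gauss = softmax(-r * dist2 / 2)         # Gaussian kernel with variance 1/r

result = w_gauss @ V                      # weighted average of the payloads
print(np.allclose(w_attn, w_gauss))       # True
```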
Can you explain the desmos plot in simple terms?
and what are the orange dots? sorry if I missed that
Disclaimer: these are from my memory, which can be wrong entirely.
Total aside, but imagining how many levels of functions are present in the calculation of each activation here, and thinking about how regular old differentiation and gradient descent actually work to train these nested parameters, is truly amazing, in my opinion.
If one thinks about it for more than a moment, it's kind of incredible that it works.
One interesting thing to note: sounds like model scaling happens on the fly by adding key-value pairs as rows in the K and V matrices on the Pattention layer. That suggests that weights represented by tokens in the first rows may be more important than weights in later rows. There may be a lot you could do with that ordering of weights in terms of pruning and such.
Consider a case of two "experts" or two "value parameter tokens."
The mixture of experts has a "router" network that provides a weight to each expert (through a softmax) conditional on an input. The output is a (sparse) weighted sum of the outputs of the experts.
The TokenFormer has an "attention" layer that combines the input token with each "key parameter" token to provide a weight for the corresponding "value parameter" token. Since A(B+C) = AB + AC, this is like applying a weighted sum of distinct transformations.
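A small check of the distributivity argument for the simplified case described here: dense (non-sparse) softmax gating and purely linear experts with no activation. The names (`router`, `W`) are illustrative and not taken from either paper. Mixing the experts' outputs and mixing the experts' weight matrices first give the same result.

```python
import numpy as np


def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()


rng = np.random.default_rng(0)
d_in, d_out, n_experts = 4, 3, 2
x = rng.normal(size=d_in)
W = rng.normal(size=(n_experts, d_in, d_out))   # each expert is a linear map
router = rng.normal(size=(d_in, n_experts))     # hypothetical router weights

g = softmax(x @ router)                         # dense gate weights, sum to 1

# MoE view: run every expert, then take the gated sum of their outputs
y_moe = sum(g[i] * (x @ W[i]) for i in range(n_experts))

# "Compute the weight on the spot" view: mix the matrices, then apply once
W_mixed = np.tensordot(g, W, axes=1)            # shape (d_in, d_out)
y_mixed = x @ W_mixed

print(np.allclose(y_moe, y_mixed))              # True
```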
I think the differences are: a) where the non-linearity hits (the above description doesn't consider an activation function), b) this attention softmax is not (necessarily) sparse, c) that "mixtral" networks only replace the feed-forward components of the layer, and d) that extending a "mixtral" approach would require re-training the "router" layers.
It seems like (d) is maybe the nicest feature here... my intuition would think (a) doesn't matter much, (b) is debatable (how closely a sparse MoE can approximate a dense MoE), (c) has probably been tried (guessing the ffwd limitation was just "more-bang-for-buck-given-parameters" not an oversight)...
... I wonder, though, if there might be diminishing returns here (I believe that Mixture-of-Experts tends to struggle with imbalanced "winner-take-all" dynamics, since "early" winners get more gradient signal to improve their weights) and how different this would have been from going from 3x7B to an 8x7B to a 24x7B training approach (with a "retrain routing networks" step).
Their claimed theoretical advancement is as follows. If you want to transform an input vector X to another vector Y of different dimension, "normal" people suggest to use a linear projection: create an appropriately sized matrix W and simply multiply it by your input:
Given X ∈ ℝ^(d_in) and W ∈ ℝ^(d_in × d_out), then Y = X @ W ∈ ℝ^(d_out).
In the attention layer, where the input X is converted into queries Q, keys K, and values V, this is the simple strategy employed: Q = X @ W_q, K = X @ W_k, V = X @ W_v, and it has shown itself to be effective.
This is too simple for the authors of this paper. They propose another approach. Instead of converting directly to the desired dimension, we will increase computation by creating an intermediate dimension, and introducing a non-linearity between them.
Given X ∈ ℝ^(d_in), W_1 ∈ ℝ^(d_in × d_tmp), and W_2 ∈ ℝ^(d_tmp × d_out), then Y = f(X @ W_1) @ W_2 ∈ ℝ^(d_out).
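A shape-level sketch of the two options, with their parameter counts. The non-linearity `f` defaults to `np.tanh` purely as a placeholder (the paper uses its own modified softmax), and the sizes are hypothetical. With d_tmp equal to d_out, the two-stage version already doubles the weights.

```python
import numpy as np


def linear_proj(X, W):
    # Y = X @ W, with W of shape (d_in, d_out)
    return X @ W


def two_stage_proj(X, W1, W2, f=np.tanh):
    # Y = f(X @ W1) @ W2, going through an intermediate width d_tmp
    return f(X @ W1) @ W2


d_in, d_tmp, d_out = 512, 512, 512
rng = np.random.default_rng(0)
X = rng.normal(size=(4, d_in))
W = rng.normal(size=(d_in, d_out))
W1, W2 = rng.normal(size=(d_in, d_tmp)), rng.normal(size=(d_tmp, d_out))

Y_lin = linear_proj(X, W)
Y_two = two_stage_proj(X, W1, W2)
print(Y_lin.shape, Y_two.shape)      # (4, 512) (4, 512)
print(W.size, W1.size + W2.size)     # 262144 524288
```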
Here, f can be any non-linearity. The authors choose softmax; it allows them to claim a superficial resemblance to attention. Later in the paper, they reveal it is not actually softmax, but a modified version to avoid gradient vanishing (softmax is not a very good general-purpose non-linearity).
So, they replace all projections in the attention layer with this new strategy: Q = f(X @ W_q1) @ W_q2, K = f(X @ W_k1) @ W_k2, and V = f(X @ W_v1) @ W_v2.
The problem with this is not theoretical: this does increase the model's expressiveness and computational power. It is practical: we are adding parameters where we need them the least, in the attention layer. It is generally understood that LLMs do not need extra parameters in the attention layer. In fact, advancements like Grouped-Query Attention hinge on the idea that you can halve or even quarter the number of key/value heads, and with them the key and value projection parameters, without noticeably harming performance. The experience of the LLM community so far suggests that the authors' idea of adding even more parameters to the self-attention layer should degrade their models' performance while adding no tangible gain.
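A quick parameter count shows where GQA's savings land. The configuration is hypothetical, chosen only to keep the arithmetic round: sharing each key/value head among four query heads cuts the K and V projections by 4x, which brings the whole attention block down to 62.5% of standard multi-head attention's weights.

```python
def attention_params(d_model, n_q_heads, n_kv_heads, head_dim):
    """Weight counts of the Q, K+V and output projections (biases ignored)."""
    q_proj = d_model * n_q_heads * head_dim
    kv_proj = 2 * d_model * n_kv_heads * head_dim   # K and V together
    o_proj = n_q_heads * head_dim * d_model
    return q_proj, kv_proj, o_proj


# Hypothetical config: d_model = 4096, 32 query heads of size 128
mha = attention_params(4096, 32, 32, 128)   # one K/V head per query head
gqa = attention_params(4096, 32, 8, 128)    # 4 query heads share each K/V head

print(mha[1] // gqa[1])        # 4: the K/V projections are 4x smaller
print(sum(gqa) / sum(mha))     # 0.625 of the total attention weights
```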
The authors' numbers say otherwise. But it is hard to trust their numbers. For the Transformer they compare against, they replicate the original GPT-2 from 2019. In doing so they ignore years of architectural improvements, such as rotary positional embeddings, SwiGLU, and RMSNorm, that have culminated in Transformer++, the strong recipe used by Meta's Llama series. We've seen this time after time in the various "Transformer killers" that used to be popular about a year ago. A researcher would think up some novel variant of linear attention, furiously test it against a weak GPT-2 baseline, find it blew it out of the water, and declare victory. Somehow, these never caught on, because when tested against a newer baseline these models weren't actually that great. The authors are doing the same thing here.
In their tables they also include comparisons to other models. Actually, they exclusively select older open model suites: EleutherAI's GPT-Neo and Pythia, and Meta's OPT. These models were trained without most modern architectural improvements (rotary embeddings, which EleutherAI popularized and used in Pythia, are the main exception), and so predictably TokenFormer crushes them. On the last page of the appendix the authors have included a full table with some fairer comparisons. Their TokenFormer-150M variant achieves a Pile ppl of 10.45 against Mamba-130M's 10.54. In the intermediate weight class, TokenFormer-450M matches Mamba-370M's 8.28 Pile ppl despite having 21% more parameters. And in the largest size, TokenFormer-1.5B loses to Mamba-1.4B, 6.91 to 6.80 ppl.
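The appendix figures quoted above, put side by side: the parameter surplus TokenFormer carries at each size against the perplexity gap it gets for it. Only the numbers cited in this comment are used; nothing else is taken from the paper.

```python
# (TokenFormer params, TokenFormer Pile ppl, Mamba params, Mamba Pile ppl)
comparisons = {
    "small": (150e6, 10.45, 130e6, 10.54),
    "medium": (450e6, 8.28, 370e6, 8.28),
    "large": (1.5e9, 6.91, 1.4e9, 6.80),
}

for name, (tf_params, tf_ppl, mamba_params, mamba_ppl) in comparisons.items():
    extra_params = tf_params / mamba_params - 1   # TokenFormer's parameter surplus
    ppl_gap = tf_ppl - mamba_ppl                  # > 0 means TokenFormer is worse
    print(f"{name}: {extra_params:+.1%} params, {ppl_gap:+.2f} ppl")
```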
Overall, the architectural tweak proposed in this paper is impractical, and the few fair comparisons they include are unimpressive. TokenFormer is another in a long line of Transformer-killers that have nice graphs of cherry-picked data, and will similarly fade into obscurity.
> The number of key-value parameter pairs in both the query-key-value and output projections corresponds directly to the hidden dimension. In contrast, the FFN module utilizes four times the number of parameter pairs relative to the hidden size.
Here is the original FFN as described in GPT-2:
y = GELU(x @ W_u) @ W_d
And here is their FFN, when understood as a special case of their "Attention":
y = modified_softmax(x @ W_k) @ W_v
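A NumPy sketch of both forms. The exact shape of the paper's modified softmax is an assumption here (GELU applied to per-row L2-normalised scores scaled by tau = sqrt(n)). Under that assumption, the TokenFormer FFN is exactly the GPT-2 FFN run on a per-token rescaled input, so the only change really is the non-linearity.

```python
import numpy as np


def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def gpt2_ffn(x, W_u, W_d):
    # y = GELU(x @ W_u) @ W_d
    return gelu(x @ W_u) @ W_d


def modified_softmax(scores, tau):
    # ASSUMPTION: per-row L2 normalisation, rescale by tau, then GELU
    norms = np.linalg.norm(scores, axis=-1, keepdims=True)
    return gelu(scores / norms * tau)


def token_ffn(x, W_k, W_v, tau):
    # y = modified_softmax(x @ W_k) @ W_v
    return modified_softmax(x @ W_k, tau) @ W_v


rng = np.random.default_rng(0)
d, n = 16, 64                  # n = 4 * d parameter pairs, as in the quote above
x = rng.normal(size=(5, d))
W_k, W_v = rng.normal(size=(d, n)), rng.normal(size=(n, d))
tau = np.sqrt(n)               # assumed scale factor

# The normalisation only rescales each input row, so the "attention" FFN
# equals the GPT-2 FFN applied to a per-token rescaled input.
scale = tau / np.linalg.norm(x @ W_k, axis=-1, keepdims=True)
print(np.allclose(token_ffn(x, W_k, W_v, tau), gpt2_ffn(x * scale, W_k, W_v)))  # True
```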
You can name the matrices whatever you want, but the grand enhancement that the authors make to the FFN is just replacing the GELU with a different non-linearity. Shazeer already conducted extensive empirical tests of different non-linearities for the FFN layer in 2020. Among the best was SwiGLU, which is used in Llama today. Unsurprisingly, a modified softmax did not make the cut.
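For comparison, a NumPy sketch of the SwiGLU FFN, using Swish with beta = 1 (SiLU). The sizes are hypothetical.

```python
import numpy as np


def silu(x):
    # Swish with beta = 1 (SiLU): x * sigmoid(x)
    return x / (1.0 + np.exp(-x))


def swiglu_ffn(x, W_gate, W_up, W_down):
    # y = (Swish(x @ W_gate) * (x @ W_up)) @ W_down
    return (silu(x @ W_gate) * (x @ W_up)) @ W_down


rng = np.random.default_rng(0)
d, d_ff = 16, 43               # d_ff ~ (2/3) * 4 * d, hypothetical sizes
x = rng.normal(size=(5, d))
W_gate, W_up = rng.normal(size=(d, d_ff)), rng.normal(size=(d, d_ff))
W_down = rng.normal(size=(d_ff, d))

y = swiglu_ffn(x, W_gate, W_up, W_down)
print(y.shape)  # (5, 16)
```

SwiGLU uses three weight matrices instead of two, which is why implementations commonly shrink d_ff to about 2/3 of the usual 4·d to keep the parameter count comparable.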
Again, if the changes in this paper were truly a step forward instead of a mindless scrambling of architecture in an effort to achieve something publishable, it would show in the results. Instead, as you can see in their appendix, TokenFormer is on-par or loses in fair comparisons to other models.
The idea of gradually increasing the size of a Transformer to save on training costs is not a novel one, and researchers have explored ideas to this effect almost since the Transformer's inception. There are many ways to do it. We can start with a small number of layers, and then add in more initialized to the identity. We can keep the number of layers constant, start with a small width, then increase the width throughout training, initializing the extra weights to zero. We can reformulate all weight matrices as LoRAs and start with a tiny rank, then slowly increase the rank until we reach a full-rank equivalent. Or we can use two or three of these strategies and mix them any way we want.
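Two of these function-preserving tricks, sketched in NumPy on a toy MLP with hypothetical sizes: widening by adding hidden units whose outgoing weights are zero, and deepening by inserting a residual block whose last projection is zero, so it starts out as the identity.

```python
import numpy as np


def mlp(x, W1, W2):
    # Two-layer ReLU MLP
    return np.maximum(x @ W1, 0.0) @ W2


rng = np.random.default_rng(0)
d, h, extra = 8, 16, 8
x = rng.normal(size=(5, d))
W1, W2 = rng.normal(size=(d, h)), rng.normal(size=(h, d))

# Width growth: new hidden units get random input weights but zero output
# weights, so they cannot change the output until training moves them.
W1_wide = np.concatenate([W1, rng.normal(size=(d, extra))], axis=1)
W2_wide = np.concatenate([W2, np.zeros((extra, d))], axis=0)
print(np.allclose(mlp(x, W1, W2), mlp(x, W1_wide, W2_wide)))  # True

# Depth growth: a new residual block with a zero output projection is the identity.
new_W1, new_W2 = rng.normal(size=(d, h)), np.zeros((h, d))
h_out = mlp(x, W1, W2)
print(np.allclose(h_out + mlp(h_out, new_W1, new_W2), h_out))  # True
```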
The performance of the resultant model is entirely dependent on what strategies you use, and how you mix them: whether you choose to increase width, depth, or rank all at once, one at a time, or somewhere in-between, and whether you increase those values linearly, exponentially, or by some new function you just thought of. Because there are so many ways to gradually increase the size of a Transformer, when you think of a new way, you've got to pick a strong baseline to compare against.
The authors choose the baseline Net2Net (2015). The paper, written two years before the invention of the Transformer, regrettably does not include pre-trained Transformer results for the authors to compare against. So, the authors train their own Net2Net model, and provide a couple nice graphs where the TokenFormer loss curve is under the Net2Net Transformer's for the entirety of training in Figure 6 and Figure 7. They provide no details of the training setup that produced these graphs: the model size, layer count, and width are all missing, as well as basic hyperparameters like the learning rate and batch size. They train on enwik8 (100MB) and seem to repeat data: near the end the TokenFormer reaches a loss below 0.5, an implausible result for English text with reasonable entropy that the model has never seen before.
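For reference, perplexity is the exponential of the cross-entropy loss, so it can never drop below 1; a value under 0.5 can only be a loss. A quick conversion, assuming (this is an assumption) the plotted quantity is a natural-log, per-character loss:

```python
import math

loss_nats = 0.5                        # per-character cross-entropy, natural log (assumed)
perplexity = math.exp(loss_nats)       # perplexity = exp(loss) >= 1 for any loss >= 0
bits_per_char = loss_nats / math.log(2)

print(round(perplexity, 2), round(bits_per_char, 2))  # 1.65 0.72
```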
Why choose this strange, home-grown baseline, reliant on a method developed in 2015, to compare against? Why not at least use a method tuned specifically for the Transformer (such as [1](https://arxiv.org/abs/2203.06211), [2](https://arxiv.org/abs/2401.02415), [3](https://arxiv.org/abs/2309.03852), to name a few)? If their progressive scaling method is truly better, it would only benefit from comparison against a strong baseline.
The authors' progressive scaling method is an idea that has been explored many times by other ML researchers. Their method in particular is compared against a weak baseline with no concrete details other than the loss graphs. In my humble opinion, it's merely an effort to shoehorn a claim of novelty into a paper that isn't.