TokenFormer: Rethinking Transformer Scaling with Tokenized Model Parameters
174 points | 2 months ago | 12 comments | arxiv.org
cs702
2 months ago
The authors factorize every weight matrix with an attention mechanism:

  weight = attention(token_query, weight_keys, weight_values).
In other words, they query weight_keys to fetch the weight_values, and mix them to compute each weight on the spot.
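A minimal sketch of the mechanism in plain Python. It assumes a standard softmax (the paper reportedly uses a modified softmax), and `pattention`, `key_params` and `value_params` are illustrative names, not the paper's code. One input token of size d_in attends over n parameter tokens and returns a mix of value rows of size d_out, which plays the role of `x @ W` in an ordinary linear layer.

```python
import math

def softmax(scores):
    # Numerically stable softmax over a list of floats.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def pattention(x, key_params, value_params):
    """Sketch only: plain softmax stands in for the paper's modified softmax.

    x:            input token, length d_in
    key_params:   n rows of length d_in  (the "weight_keys")
    value_params: n rows of length d_out (the "weight_values")
    Returns a vector of length d_out, standing in for x @ W.
    """
    scores = [sum(xi * ki for xi, ki in zip(x, k)) for k in key_params]
    weights = softmax(scores)
    d_out = len(value_params[0])
    return [sum(w * v[j] for w, v in zip(weights, value_params))
            for j in range(d_out)]

# Two parameter tokens; a zero query weights them equally.
print(pattention([0.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 4.0]]))  # [1.0, 2.0]
```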

Increasing model size becomes a matter of adding more weight_keys and weight_values, and incrementally training them.
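One way to see why adding parameter tokens can be incremental: if the per-token scores are not normalised across tokens and the new value rows start at zero, the grown layer computes exactly what the old one did. This sketch swaps softmax for an unnormalised GeLU score (an assumption, loosely in the spirit of the paper's modified softmax), and the zero-initialisation of new value rows is likewise an assumed scheme, not necessarily the authors'. The new value rows still get a non-zero gradient, because their GeLU scores are non-zero, so they can start learning.

```python
import math

def gelu(z):
    # Exact GeLU via the error function.
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))

def pattention_gelu(x, key_params, value_params):
    # Assumption: each parameter token's score goes through GeLU independently,
    # so adding tokens does not rescale the existing ones (unlike softmax).
    scores = [gelu(sum(xi * ki for xi, ki in zip(x, k))) for k in key_params]
    d_out = len(value_params[0])
    return [sum(s * v[j] for s, v in zip(scores, value_params))
            for j in range(d_out)]

def grow(key_params, value_params, new_keys):
    # Hypothetical growth step: append new key rows and zero-initialised value rows.
    d_out = len(value_params[0])
    new_values = [[0.0] * d_out for _ in new_keys]
    return key_params + new_keys, value_params + new_values

keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0]]
x = [0.5, -0.25]

before = pattention_gelu(x, keys, values)
keys2, values2 = grow(keys, values, [[0.3, 0.7], [-1.0, 0.2]])
after = pattention_gelu(x, keys2, values2)
print(len(keys2), before == after)  # 4 True
```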

Simple, clever, and it seems to work well. Beautiful.

szcs
2 months ago
There is a particularly nice geometric interpretation of attention I just realised recently in a flash of enlightenment, best explained with an interactive Desmos plot (black dot is draggable):

https://www.desmos.com/calculator/3rtqsyapxo

The above assumes the columns of K are normalised, but bear with me. K and V together form a vector database. V holds the payloads, each row containing a vector of data. K describes the positions of these points in space, on the surface of a hypersphere. The query vector describes the query into the database: its direction picks the point in space being queried, and its magnitude sets the radius of the query (a larger magnitude means a smaller radius). The result is the weighted average of the rows of V, each weighted by how far its key lies from the query direction, with a smooth Gaussian falloff whose width is set by the query radius. I recommend a recent paper from Nvidia that derives a significant speedup by normalising vectors onto a hypersphere: https://arxiv.org/abs/2410.01131v1
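The Gaussian picture can be checked numerically. With unit-length keys k and a query q = r·q̂ (r = |q|), ‖q̂ − k‖² = 2 − 2 q̂·k, so exp(q·k) = exp(r) · exp(−(r/2)‖q̂ − k‖²). Softmax attention is therefore a Gaussian kernel on the distance between the query direction and each key, the constant exp(r) cancels in the normalisation, and the kernel width scales like 1/r. A sketch with illustrative function names, comparing both forms for keys on a circle:

```python
import math

def softmax_weights(q, keys):
    # Ordinary attention weights: softmax of the dot products q·k.
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def gaussian_weights(q, keys):
    # Same weights written as a Gaussian kernel on distance to the query direction.
    r = math.sqrt(sum(a * a for a in q))        # query magnitude
    q_hat = [a / r for a in q]                  # query direction (unit vector)
    dist2 = [sum((a - b) ** 2 for a, b in zip(q_hat, k)) for k in keys]
    exps = [math.exp(-0.5 * r * d) for d in dist2]
    total = sum(exps)
    return [e / total for e in exps]

# Eight unit-length keys spread around a circle (the "green dots").
keys = [[math.cos(2 * math.pi * i / 8), math.sin(2 * math.pi * i / 8)] for i in range(8)]

for q in ([1.0, 0.5], [4.0, 2.0]):              # same direction, larger magnitude
    w_soft = softmax_weights(q, keys)
    w_gauss = gaussian_weights(q, keys)
    assert all(abs(a - b) < 1e-12 for a, b in zip(w_soft, w_gauss))
    print(q, round(max(w_soft), 3))             # the larger query is more peaked
```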

jdthedisciple
2 months ago
It looks fascinating, but I don't understand it. I haven't gone deeply into the theory of attention networks yet.

Can you explain the desmos plot in simple terms?

szcs
2 months ago
Attention is a product of three matrices, s(QK)V, where s is softmax. Each matrix has as many rows (Q and V) or columns (K) as there are tokens in your context. The plot looks at the processing of a single row of Q (predicting a single token from the previous ones), called q. q is a 2-element vector and is visualised as the draggable dot (imagine a line from the origin to the dot). The K matrix is shown as green dots: each previous token in the context window is represented as a separate dot. The distance of a blue dot from its corresponding green dot represents how much information from that token gets mixed into the output of the query. The green dots lie on a hypersphere, here a circle, i.e. a 1D manifold in 2D space. In a real network it would be more like a 127D manifold in 128D space, but the analogy still works there. You can see how the query gathers information stored on the surface of the manifold by selecting a region of it, whose position is set by q's direction and whose size is set by q's magnitude.
jdthedisciple
2 months ago
oh wow that makes more sense now!

and what are the orange dots? sorry if I missed that

szcs
2 months ago
That's just the same distribution laid out along a line instead of a circle
liuliu
2 months ago
Yeah, I believe this intuition was first introduced by the Neural Turing Machine line of work and later simplified in the AIAYN paper (an NTM maintains an "external memory", a.k.a. weight_keys and weight_values here).

Disclaimer: these are from my memory, which can be wrong entirely.

anon291
2 months ago
I believe there have been studies showing that the attention mechanism allows estimation of gradients for one-shot learning (i.e., based on what you tell the model you want in the input, it will use attention to 'update' the weights of the linear layers to 'learn' new information). This seems to be taking that one step further and just using attention for the weight estimation itself. The key insight here is that by adding more tokens to the weight estimation calculation, you can get more degrees of freedom.
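A toy version of the kind of result described above, under simplifying assumptions (linear regression, a single gradient-descent step from w = 0, and unnormalised linear attention with no softmax): the prediction after one GD step on the in-context examples equals a linear-attention read-out with the examples' inputs as keys, their targets as values, and the test input as the query. The function names are illustrative, and this is not the TokenFormer mechanism itself.

```python
def gd_one_step_prediction(xs, ys, x_query, lr):
    # Toy setting: least-squares loss L(w) = 0.5 * sum_i (y_i - w·x_i)^2, starting from w = 0.
    d = len(x_query)
    w = [0.0] * d
    grad = [0.0] * d
    for x, y in zip(xs, ys):
        residual = y - sum(wj * xj for wj, xj in zip(w, x))
        for j in range(d):
            grad[j] -= residual * x[j]          # dL/dw_j = -sum_i residual_i * x_ij
    w = [wj - lr * gj for wj, gj in zip(w, grad)]
    return sum(wj * qj for wj, qj in zip(w, x_query))

def linear_attention(keys, values, query, scale):
    # Unnormalised linear attention: scale * sum_i value_i * (key_i · query).
    return scale * sum(v * sum(kj * qj for kj, qj in zip(k, query))
                       for k, v in zip(keys, values))

xs = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]
ys = [1.0, -2.0, 0.5]
x_query = [0.2, 0.4]
a = gd_one_step_prediction(xs, ys, x_query, lr=0.1)
b = linear_attention(xs, ys, x_query, scale=0.1)
print(abs(a - b) < 1e-12)  # True
```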

Total aside, but imagining how many levels of functions are present in the calculation of each activation here, and thinking about how regular old differentiation and gradient descent actually work to train these nested parameters, is truly amazing, in my opinion.

cs702
2 months ago
Yeah. This thing is "assembling a different transformer" on the spot for each token.

If one thinks about it for more than a moment, it's kind of incredible that it works.

0-_-0
2 months ago
I think the same about regular neural networks
amelius
2 months ago
So, is this the "neural net" way of using indirection? Like, instead of writing f(t), you now use f(W[t]), where W is some table? And then you use a dot product and write f(W.t) just because you can (or because there's no other way to implement indirection in a neural net)?
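The indirection reading can be made concrete: multiplying a table by a one-hot vector is exactly the lookup W[t], and replacing the one-hot vector with softmax weights turns it into a differentiable "soft" lookup that blends rows. A small sketch; all names are hypothetical.

```python
import math

def one_hot(index, n):
    return [1.0 if i == index else 0.0 for i in range(n)]

def weighted_rows(table, weights):
    # sum_i weights[i] * table[i]: a hard lookup if weights is one-hot,
    # a soft (differentiable) blend of rows otherwise.
    width = len(table[0])
    return [sum(w * row[j] for w, row in zip(weights, table)) for j in range(width)]

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

W = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(weighted_rows(W, one_hot(1, 3)))             # [3.0, 4.0], same as W[1]
print(weighted_rows(W, softmax([0.0, 0.0, 0.0])))  # roughly the average row
```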
valine
2 months ago
I would like to see a comparison of inference-time compute between a regular transformer and this. I’m assuming tokens/s is lower since you need to compute the weights of the model for each token prior to the actual attention calculations for the sequence position.
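A rough back-of-the-envelope, counting only multiply-accumulates (MACs) per token and ignoring the score non-linearity and the token-token attention: a linear d_in × d_out projection costs d_in·d_out MACs, while a Pattention layer with n parameter tokens costs n·d_in for the scores plus n·d_out for the mix. In both cases that works out to one MAC per parameter, so per-token cost tracks parameter count and no separate "compute the weights" pass is needed. The choice n = d follows the paper quote later in the thread that the number of key-value pairs in the QKV projections equals the hidden dimension; the function names and d = 1024 are illustrative assumptions.

```python
# Rough MAC counts only; ignores the non-linearity and token-token attention.

def linear_macs(d_in, d_out):
    # y = x @ W with W of shape (d_in, d_out)
    return d_in * d_out

def pattention_macs(d_in, d_out, n_tokens):
    # scores = x @ K_P^T   -> n_tokens * d_in MACs
    # output = scores @ V_P -> n_tokens * d_out MACs
    return n_tokens * d_in + n_tokens * d_out

def pattention_params(d_in, d_out, n_tokens):
    # n key rows of length d_in plus n value rows of length d_out
    return n_tokens * (d_in + d_out)

d = 1024
print(linear_macs(d, d))                                       # 1048576
print(pattention_macs(d, d, n_tokens=d))                       # 2097152
print(pattention_macs(d, d, d) == pattention_params(d, d, d))  # True
```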
logicchains
2 months ago
Isn't that Figure 5 in the paper? It's for training, not inference, but presumably if training is faster then inference would be too. That's because they don't increase the dimension of the text tokens when scaling up, which reduces the compute needed for token-token attention. But that potentially limits how well the text-token attention can keep track of things, because it has less space for passing information along.
paraschopra
2 months ago
Why would it be higher? You can keep KV cache precomputed like before.
davesque
2 months ago
Seems like a big deal. I feel like this could enable a new level of modularity and compatibility between publicly available weight sets, assuming they use similar channel dimensions. Maybe it also provides a nice formalism for thinking about fine tuning, where you could adopt certain heuristics for adding/removing key-value pairs from the Pattention layers.

One interesting thing to note: sounds like model scaling happens on the fly by adding key-value pairs as rows in the K and V matrices on the Pattention layer. That suggests that weights represented by tokens in the first rows may be more important than weights in later rows. There may be a lot you could do with that ordering of weights in terms of pruning and such.

valine
2 months ago
Unless I’m reading it wrong I don’t think rows matter. Attention doesn’t take into account sequence position natively, that’s why positional encodings exist.
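The same permutation argument applies to the parameter tokens themselves: with no positional encoding on the parameter keys and values, shuffling their rows together leaves the layer's output unchanged. Any special importance of early rows would therefore have to come from training history rather than from the computation. A quick check with a plain-softmax stand-in for the paper's layer (illustrative names, not the paper's code):

```python
import math
import random

def pattention(x, key_params, value_params):
    # Plain-softmax stand-in: softmax over x·k per parameter token, then mix value rows.
    scores = [sum(a * b for a, b in zip(x, k)) for k in key_params]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    d_out = len(value_params[0])
    return [sum(e / total * v[j] for e, v in zip(exps, value_params)) for j in range(d_out)]

random.seed(0)
keys = [[random.gauss(0, 1) for _ in range(4)] for _ in range(6)]
values = [[random.gauss(0, 1) for _ in range(3)] for _ in range(6)]
x = [random.gauss(0, 1) for _ in range(4)]

order = list(range(6))
random.shuffle(order)
shuffled_keys = [keys[i] for i in order]
shuffled_values = [values[i] for i in order]   # permute keys and values together

a = pattention(x, keys, values)
b = pattention(x, shuffled_keys, shuffled_values)
print(all(abs(p - q) < 1e-12 for p, q in zip(a, b)))  # True
```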
davesque
2 months ago
I'm talking about the rows in the new K and V matrices introduced by the paper, not rows in the input sequence. The ordering of rows in the new K and V matrices does matter in the sense that rows that appear further down were added later in the training process to add new parameter tokens during scaling. So those newer parameters may represent knowledge that is less fundamental and more about fine tuning on the training set.
paraschopra
2 months ago
But after adding new rows, I think the entire network is retrained.
goldenshale
2 months ago
This is a great idea. Being able to dynamically scale up model sizes as datasets and use cases expand without needing to retrain from scratch could enable a Cambrian explosion of interesting stuff building on top of a Llama type model trained in this way.
eric15342335
2 months ago
I am a second-year university student learning the basic mathematics and statistics behind neural networks. One thing that shocks me is that there isn't an "incremental" way to build a larger (more parameters) AI model like GPT-4 from an existing smaller one, e.g. GPT-3.5 (I see the term "incremental (compiling)" nearly everywhere in the software engineering industry). I am curious why this isn't possible theoretically.
mynegation
2 months ago
It is possible, just not practical in many cases. For incremental computations you should be able to either reverse the computation or store the inputs _and_ intermediate results. And you have to repeat some non trivial share of computations anyway, possibly all of it. For AI training this is prohibitively expensive and it is simpler to train from scratch. Not saying it is impossible but demand so far is not there.
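As an example of one such incremental step, here is a sketch of a function-preserving widening in the style of Net2Net (the 2015 method sapphire42 mentions further down the thread). It copies a hidden unit, then splits that unit's outgoing weights between the original and the copy. The wider network computes the same function, so training can resume from it. The tiny one-hidden-layer network, the tanh activation and the function names are illustrative choices.

```python
import math

def mlp(x, w1, w2):
    # One hidden layer: h = tanh(w1 x), y = w2 h.
    h = [math.tanh(sum(a * b for a, b in zip(row, x))) for row in w1]
    return [sum(a * b for a, b in zip(row, h)) for row in w2]

def widen(w1, w2, unit):
    # Net2Net-style widening sketch: duplicate hidden unit `unit` and split its
    # outgoing weights in half between the original and the copy.
    new_w1 = w1 + [list(w1[unit])]
    new_w2 = []
    for row in w2:
        half = row[unit] / 2.0
        new_row = list(row)
        new_row[unit] = half
        new_w2.append(new_row + [half])
    return new_w1, new_w2

w1 = [[0.5, -1.0], [1.5, 0.25]]       # 2 hidden units, 2 inputs
w2 = [[2.0, -0.5], [0.3, 1.2]]        # 2 outputs
x = [0.7, -0.4]

y_small = mlp(x, w1, w2)
w1_big, w2_big = widen(w1, w2, unit=0)
y_big = mlp(x, w1_big, w2_big)
print(len(w1_big), all(abs(a - b) < 1e-12 for a, b in zip(y_small, y_big)))  # 3 True
```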
c0g
2 months ago
logicchains
2 months ago
Seems this would naturally translate into a mixture of experts by using a "hard" attention function, so that only a fixed number of weight tokens get included in the calculation.
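A sketch of that suggestion, assuming a simple top-k rule (a hypothetical variant, not something the paper does): score every parameter token, keep the k highest, and renormalise the softmax over just those. Only k value rows are touched per input, as in a sparse mixture of experts.

```python
import math

def topk_pattention(x, key_params, value_params, k):
    # Hypothetical top-k variant; not part of the paper.
    scores = [sum(a * b for a, b in zip(x, key)) for key in key_params]
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    m = max(scores[i] for i in top)
    exps = {i: math.exp(scores[i] - m) for i in top}
    total = sum(exps.values())
    d_out = len(value_params[0])
    # Only the selected value rows contribute to the output.
    return [sum(exps[i] / total * value_params[i][j] for i in top) for j in range(d_out)]

keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
values = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
x = [2.0, 1.0]
print(topk_pattention(x, keys, values, k=1))  # [1.0, 0.0]: only the best-matching token is used
```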
ml_thoughts
2 months ago
This seems closely related to the "Mixtral" approach of a mixture-of-experts transformer [1]... I'm not claiming the approach is not original, it just helped me understand what was going on.

Consider a case of two "experts" or two "value parameter tokens."

The mixture of experts has a "router" network that provides a weight to each expert (through a softmax) conditional on an input. The output is a (sparse) weighted sum of the outputs of the experts.

The TokenFormer has an "attention" layer that combines the token with a key parameter to provide a weight for each "value parameter" token. Since A(B+C) = AB + AC by definition, this is like applying a weighted sum of distinct transformations.
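The distributivity point can be checked directly for linear experts. Blending the expert weight matrices with gate weights and then applying the result gives the same output as applying each expert and blending their outputs, which is also one way to read "assembling a different transformer on the spot". The gate values and matrices below are made-up numbers, and the function names are illustrative.

```python
def matvec(W, x):
    return [sum(a * b for a, b in zip(row, x)) for row in W]

def blend_outputs(gates, experts, x):
    # Mixture-of-experts style: run every (linear) expert, then mix their outputs.
    outs = [matvec(W, x) for W in experts]
    return [sum(g * o[i] for g, o in zip(gates, outs)) for i in range(len(outs[0]))]

def blend_weights(gates, experts, x):
    # Assemble one weight matrix on the spot, then apply it once.
    rows, cols = len(experts[0]), len(experts[0][0])
    W = [[sum(g * E[r][c] for g, E in zip(gates, experts)) for c in range(cols)]
         for r in range(rows)]
    return matvec(W, x)

experts = [[[1.0, 2.0], [0.0, 1.0]], [[-1.0, 0.5], [3.0, 0.0]]]
gates = [0.25, 0.75]          # e.g. softmax outputs of a router
x = [2.0, -1.0]
a = blend_outputs(gates, experts, x)
b = blend_weights(gates, experts, x)
print(all(abs(p - q) < 1e-12 for p, q in zip(a, b)))  # True
```

With a non-linearity inside each expert the two orders would no longer agree, which is difference (a) in the list that follows.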

I think the differences are: a) where the non-linearity hits (the above description doesn't consider an activation function), b) this attention softmax is not (necessarily) sparse, c) that "mixtral" networks only replace the feed-forward components of the layer, and d) that extending a "mixtral" approach would require re-training the "router" layers.

It seems like (d) is maybe the nicest feature here... my intuition would think (a) doesn't matter much, (b) is debatable (how close a sparse-MoE can approximate a dense-MoE), (c) has probably been tried (guessing the ffwd limitation was just "more-bang-for-buck-given-parameters" not an oversight)...

... I wonder, though, whether there might be diminishing returns here (I believe Mixture-of-Experts tends to struggle with imbalanced "winner-take-all" dynamics, since "early" winners get more gradient signal to improve their weights), and how different this would have been from going from 3x7B to 8x7B to 24x7B training (with a "retrain routing networks" step).

[1] https://arxiv.org/abs/2401.04088

mentalically
2 months ago
Eventually people will figure out how to nest neural networks in the nodes and edges of an arbitrary graph.
davesque
2 months ago
Seems like a lot of existing models could be converted to this token parameter representation.
sapphire42
2 months ago
As someone who has worked in this space, this paper is unfortunately total BS.

Their claimed theoretical advancement is as follows. If you want to transform an input vector X to another vector Y of different dimension, "normal" people suggest to use a linear projection: create an appropriately sized matrix W and simply multiply it by your input:

Given X ∈ ℝ^(d_in) and W ∈ ℝ^(d_in × d_out), then Y = X @ W ∈ ℝ^(d_out).

In the attention layer, where the input X is converted into queries Q, keys K, and values V, this is the simple strategy employed: Q = X @ W_q, K = X @ W_k, V = X @ W_v, and it has shown itself to be effective.

This is too simple for the authors of this paper. They propose another approach. Instead of converting directly to the desired dimension, we will increase computation by creating an intermediate dimension, and introducing a non-linearity between them.

Given X ∈ ℝ^(d_in), W_1 ∈ ℝ^(d_in × d_tmp), and W_2 ∈ ℝ^(d_tmp × d_out), then Y = f(X @ W_1) @ W_2 ∈ ℝ^(d_out).

Here, f can be any non-linearity. The authors choose softmax; it allows them to claim a superficial resemblance to attention. Later in the paper, they reveal it is not actually softmax, but a modified version to avoid gradient vanishing (softmax is not a very good general-purpose non-linearity).

So, they replace all projections in the attention layer with this new strategy: Q = f(X @ W_q1) @ W_q2, K = f(X @ W_k1) @ W_k2, and V = f(X @ W_v1) @ W_v2.

The problem with this is not theoretical: this does increase the model's expressiveness and computational power. It is practical: we are adding parameters where we need them least, in the attention layer. It is generally understood that LLMs do not need extra parameters in the attention layer. In fact, advancements like Grouped-Query Attention hinge on the idea that you can halve or even quarter the number of key/value projection parameters in the attention layer, by sharing key/value heads across query heads, without harming performance. The experience of the LLM community so far suggests that the authors' idea of adding even more parameters to the self-attention layer should degrade their models' performance while adding no tangible gain.

The authors' numbers say otherwise. But it is hard to trust their numbers. When training a Transformer to compare against, they replicate the original GPT-2 proposed in 2019. In doing so they ignore years of architectural improvements, such as rotary positional embeddings, SwiGLU, and RMSNorm, that have culminated in Transformer++, the strong recipe that Meta's Llama series uses. We've seen this time after time in the various "Transformer killers" that were popular about a year ago. A researcher would think up some novel variant of linear attention, furiously test it against a weak GPT-2 baseline, find that it blew the baseline out of the water, and declare victory. Somehow, these never caught on, because when tested against a newer baseline these models weren't actually that great. The authors are doing the same thing here.

In their tables they also include comparisons to other models. Actually, they exclusively select older open model suites: EleutherAI's GPT-Neo and Pythia, plus Meta's OPT. These models were not trained with any modern architectural improvements except rotary embeddings (used in Pythia, and popularized though not invented by EleutherAI), and so predictably TokenFormer crushes them. On the last page of the appendix the authors have included a full table with some fairer comparisons. Their TokenFormer-150M variant achieves a Pile ppl of 10.45 against Mamba-130M's 10.54. In the intermediate weight class, TokenFormer-450M matches Mamba-370M's 8.28 Pile ppl despite having 21% more parameters. And in the largest size, TokenFormer-1.5B loses to Mamba-1.4B, 6.91 to 6.80 ppl.

Overall, the architectural tweak proposed in this paper is impractical, and the few fair comparisons they include are unimpressive. TokenFormer is another in a long line of Transformer-killers that have nice graphs of cherry-picked data, and will similarly fade into obscurity.

riley_s8
1 month ago
totally agree. It doesn't make any sense to use linear(softmax(linear(x))) to replace linear(x) while claiming to be more explainable and more scalable.
logicchains
2 months ago
I feel like you fundamentally misunderstood the paper. It's not only the attention weights; the weights in the MLP layer that follows each attention layer are also generated based on the methodology they describe.
sapphire42
2 months ago
Yes, and this results in the MLP layer being functionally unchanged. In the vanilla GPT-2 Transformer, the MLP layer is defined as a 4x up-projection, then a non-linearity, followed by a 4x down-projection. This can be understood as a specific case of their method, as they describe here:

> The number of key-value parameter pairs in both the query-key-value and output projections corresponds directly to the hidden dimension. In contrast, the FFN module utilizes four times the number of parameter pairs relative to the hidden size.

Here is the original FFN as described in GPT-2:

y = GELU(x @ W_u) @ W_d

And here is their FFN, when understood as a special case of their "Attention":

y = modified_softmax(x @ W_k) @ W_v
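Both forms fit one template, y = act(x @ W_1) @ W_2, where only the activation differs: an element-wise GELU for the GPT-2 FFN, or a softmax over the hidden units for the attention-style reading (rows of the second matrix act as "value parameter tokens", columns of the first as keys). This sketch uses a plain softmax as a stand-in for the paper's modified softmax; the matrices and function names are illustrative.

```python
import math

def gelu(z):
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))

def softmax(zs):
    m = max(zs)
    exps = [math.exp(z - m) for z in zs]
    total = sum(exps)
    return [e / total for e in exps]

def two_layer(x, w_1, w_2, act_vector):
    # y = act(x @ W_1) @ W_2; w_1 is d_in x d_hidden, w_2 is d_hidden x d_out.
    hidden = [sum(x[i] * w_1[i][h] for i in range(len(x))) for h in range(len(w_1[0]))]
    a = act_vector(hidden)
    return [sum(a[h] * w_2[h][o] for h in range(len(a))) for o in range(len(w_2[0]))]

def gpt2_ffn(x, w_u, w_d):
    # Element-wise GELU between the up- and down-projections.
    return two_layer(x, w_u, w_d, lambda hs: [gelu(h) for h in hs])

def attention_style_ffn(x, w_k, w_v):
    # Same shapes; the hidden vector is normalised with a (plain) softmax instead.
    return two_layer(x, w_k, w_v, softmax)

w_in = [[0.2, -0.1, 0.4, 0.0], [0.5, 0.3, -0.2, 0.1]]       # d_in=2, d_hidden=4
w_out = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [-1.0, 2.0]]   # d_hidden=4, d_out=2
x = [1.0, -2.0]
print(gpt2_ffn(x, w_in, w_out))
print(attention_style_ffn(x, w_in, w_out))
```

One visible difference: the softmax version always returns a convex combination of the rows of the second matrix, while GELU imposes no such constraint.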

You can name the matrices whatever you want, but the grand enhancement that the authors make to the FFN is just replacing the GELU with a different non-linearity. Shazeer already conducted extensive empirical tests of different non-linearities for the FFN layer in 2020. Among the best was SwiGLU, which is used in Llama today. Unsurprisingly, a modified softmax did not make the cut.

Again, if the changes in this paper were truly a step forward instead of a mindless scrambling of architecture in an effort to achieve something publishable, it would show in the results. Instead, as you can see in their appendix, TokenFormer is on-par or loses in fair comparisons to other models.

boroboro4
2 months ago
Isn't their main claim the ability to gradually increase the number of weights and save on total training cost, rather than just the expressiveness/efficiency of the architecture?
sapphire42
2 months ago
They do actually make several claims as to the efficiency of the architecture compared to the Transformer, as you can see by the many graphs throughout the document. Their claim that their architecture is the only one that allows for gradually increasing the number of weights is a prominent one too, though, so I'll explain why I don't find that claim credible.

The idea of gradually increasing the size of a Transformer to save on training costs is not a novel one, and researchers have explored ideas to this effect almost since the Transformer's inception. There are many ways to do it. We can start with a small number of layers, and then add more layers initialized to the identity. We can keep the number of layers constant, start with a small width, and increase it throughout training, initializing the extra weights to zero. We can reformulate all weight matrices as LoRAs and start with a tiny rank, then slowly increase the rank until we reach a full-rank equivalent. Or we can use two or three of these strategies and mix them any way we want.
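The depth variant can be sketched in a few lines. A new residual block x + W_2 relu(W_1 x) whose output projection W_2 starts at zero is exactly the identity, so inserting it does not change what the network computes. W_2 still receives a gradient once training resumes, because relu(W_1 x) is generally non-zero. Names, shapes, the ReLU and the init scale are illustrative choices, not taken from any of the cited papers.

```python
import random

def relu(z):
    return max(0.0, z)

def residual_block(x, w_1, w_2):
    # y = x + W_2 relu(W_1 x)
    h = [relu(sum(a * b for a, b in zip(row, x))) for row in w_1]
    delta = [sum(a * b for a, b in zip(row, h)) for row in w_2]
    return [xi + di for xi, di in zip(x, delta)]

def new_identity_block(d_model, d_hidden, rng):
    # Random input projection, zero output projection: the block starts as identity.
    w_1 = [[rng.gauss(0.0, 0.02) for _ in range(d_model)] for _ in range(d_hidden)]
    w_2 = [[0.0] * d_hidden for _ in range(d_model)]
    return w_1, w_2

rng = random.Random(0)
x = [0.3, -1.2, 0.8, 0.05]
w_1, w_2 = new_identity_block(d_model=4, d_hidden=8, rng=rng)
print(residual_block(x, w_1, w_2) == x)  # True: the inserted layer changes nothing yet
```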

The performance of the resultant model is entirely dependent on what strategies you use, and how you mix them: whether you choose to increase width, depth, or rank all at once, one at a time, or somewhere in-between, and whether you increase those values linearly, exponentially, or by some new function you just thought of. Because there are so many ways to gradually increase the size of a Transformer, when you think of a new way, you've got to pick a strong baseline to compare against.

The authors choose the baseline Net2Net (2015). The paper, written two years before the invention of the Transformer, regrettably does not include pre-trained Transformer results for the authors to compare against. So, the authors train their own Net2Net model, and provide a couple of nice graphs where the TokenFormer loss curve is under the Net2Net Transformer's for the entirety of training in Figure 6 and Figure 7. They provide no details of the training setup that produced these graphs: the model size, layer count, and width are all missing, as well as basic hyperparameters like the learning rate and batch size. They train on enwik8 (100MB) and seem to repeat data: near the end the TokenFormer's loss drops below 0.5, an implausible result for English text with reasonable entropy that a language model has never seen before.

Why choose this strange, home-grown baseline, reliant on a method developed in 2015, to compare against? Why not at least use a method tuned specifically for the Transformer (such as [1](https://arxiv.org/abs/2203.06211), [2](https://arxiv.org/abs/2401.02415), [3](https://arxiv.org/abs/2309.03852), to name a few)? If their progressive scaling method is truly better, it would only benefit from comparison against a strong baseline.

The authors' progressive scaling method is an idea that has been explored many times by other ML researchers. Their method in particular is compared against a weak baseline with no concrete details other than the loss graphs. In my humble opinion, it's merely an effort to shoehorn a claim of novelty into a paper that isn't.

a_wild_dandan
2 months ago
This could be revolutionary. The PPL/compute graphs are damning. If the Transformer is a function, then the TokenFormer feels like a higher-order function. Perhaps this approach is a natural direction for producing System Two reasoning? There's so much to digest here...