MegaTrain: Full Precision Training of 100B+ Parameter LLMs on a Single GPU
193 points
6 hours ago
| 11 comments
| arxiv.org
| HN
internetguy
5 hours ago
[-]
> MegaTrain stores parameters and optimizer states in host memory (CPU memory) and treats GPUs as transient compute engines. For each layer, we stream parameters in and compute gradients out, minimizing persistent device state

This is pretty awesome. The only compute I have at home is an RTX 3080 with 10 GB of VRAM, so I struggle with training larger models (>40M, 50M params). I get OOM errors and have to optimize a lot.

I have a lot more CPU RAM in my PC, and this would likely increase the size of models I can train locally.
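A toy sketch of the streaming idea the quoted passage describes (my own illustration with made-up sizes, not the paper's code; array copies stand in for host-to-GPU transfers):

```python
import numpy as np

# Parameters live in host (CPU) memory; for each layer we "stage" a copy
# onto the device, run the computation, and free it, so device memory only
# ever holds one layer's weights at a time.
rng = np.random.default_rng(0)
host_params = [rng.standard_normal((64, 64)) for _ in range(8)]  # stays in RAM

def forward_streamed(x):
    peak_device_layers, device_resident = 0, 0
    for w_host in host_params:
        w_dev = w_host.copy()           # stand-in for host->GPU transfer
        device_resident += 1
        peak_device_layers = max(peak_device_layers, device_resident)
        x = np.maximum(w_dev @ x, 0.0)  # the layer's compute (ReLU MLP layer)
        del w_dev                       # stand-in for freeing device memory
        device_resident -= 1
    return x, peak_device_layers

y, peak = forward_streamed(rng.standard_normal(64))
print(peak)  # only one layer's weights resident at any moment
```

The real system overlaps the transfer of layer i+1 with the compute of layer i to hide PCIe latency, which this sequential sketch omits.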

reply
weitendorf
4 hours ago
[-]
To make the most of these architectures, I think the key is essentially moving more of the knowledge/capabilities out of the "weights" and into the complementary parts of the system, in a way that's proportionate to the capabilities of the hardware.

In the past couple of months there's been a kind of explosion in small models occupying a niche in this kind of AI-transcoding space. What I'm hoping we're right on the cusp of is a similar explosion in what I'd call tool-adaptation, where an LLM paired with some mostly-fixed suite of tools and problem cases can trade off some generality for a specialized (potentially hyper-specialized to the company or user) role.

The thing about transcoding-related tasks is that they generally stay in sync with what the user of the device is actively doing, which will also typically be closely aligned with the capabilities of the user's hardware and what they want to do with their computer. Most people aren't being intentional about this kind of stuff right now, partly out of habit I think, because only just now does it make sense to think of a personal computer as "stranded hardware", now that it can be steered/programmed somewhat autonomously.

I'm wondering if, with the right approach to MoE on local devices (which local LLMs are heading towards), we could basically amortize the expensive hit from loading weights in and out of VRAM through some kind of extreme batch use case that users still find useful enough to be worth the latency. LoRA is already really useful for this, but obviously sometimes you need more expertise/specialization than just a few layers' difference. I'm experimenting with this right now. It's the same basic principle as in the paper, except less of a technical optimization and more of a workload optimization. Also, it's literally the beginning of machine culture, so that's kind of cool.

reply
spacebacon
3 hours ago
[-]
You are on the right track. Check out the Semiotic-Reflexive Transformer (SRT) here.

https://open.substack.com/pub/sublius/p/the-semiotic-reflexi...

reply
hirako2000
4 hours ago
[-]
The article's claims assume far more compute and far more VRAM. While the trick reduces the back and forth, it doesn't eliminate it.

I doubt you meant 50M. Rather 50B?

You can only give it a try, but don't get your hopes up for a large context. If their technique works, I would guess an 8192-token context would still OOM. 2048, maybe.

I'm extrapolating based on my experiment without this paper's trick to leverage the system memory.

reply
kouteiheika
3 hours ago
[-]
> You can only give it a try, but don't get your hopes high on a large context.

You may or may not know this, but: when training off-the-shelf LLMs (i.e. ones which have a huge vocabulary), what consumes a huge amount of memory is calculating the cross-entropy loss (which gets worse the more tokens you stuff in your batch), so always use a fused cross-entropy kernel.

For example, for a Gemma 2 model with 2B parameters at a batch size of 8k this consumes 24GB of VRAM by default (!); you can fuse your cross-entropy loss with @torch.compile and that can cut down this memory usage to something like a few gigabytes, but with a dedicated kernel this becomes a few megabytes.
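The memory win comes from never materializing the full [tokens x vocab] logits matrix. A fused kernel does this on-device; the same principle can be sketched in NumPy by chunking the final projection + loss (my own toy illustration with made-up sizes, not the Liger or torch.compile implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
T, H, V = 512, 64, 10_000          # tokens, hidden size, vocab (toy scale)
hidden = rng.standard_normal((T, H))
w_out = rng.standard_normal((H, V)) * 0.02
targets = rng.integers(0, V, size=T)

def ce_naive(hidden, w_out, targets):
    # materializes the full [T, V] logits matrix: the memory hog
    logits = hidden @ w_out
    logits -= logits.max(axis=1, keepdims=True)
    lse = np.log(np.exp(logits).sum(axis=1))
    return (lse - logits[np.arange(len(hidden)), targets]).mean()

def ce_chunked(hidden, w_out, targets, chunk=64):
    # only ever materializes [chunk, V] logits: same loss, ~T/chunk less memory
    total = 0.0
    for s in range(0, len(hidden), chunk):
        h, t = hidden[s:s + chunk], targets[s:s + chunk]
        logits = h @ w_out
        logits -= logits.max(axis=1, keepdims=True)
        lse = np.log(np.exp(logits).sum(axis=1))
        total += (lse - logits[np.arange(len(t)), t]).sum()
    return total / len(hidden)

print(abs(ce_naive(hidden, w_out, targets) - ce_chunked(hidden, w_out, targets)))
```

At real scale (vocab 256k, batch 8k tokens) the full logits matrix alone is tens of GB in fp32, which is where the numbers above come from.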

reply
gavinray
2 hours ago
[-]
I'd not heard of this before; a quick search turned up this 2025 post, which suggests a "fused cross-entropy loss" kernel was integrated into PyTorch:

https://pytorch.org/blog/peak-performance-minimized-memory/

  > "The integration involves modifying the TransformerDecoder module in torchtune to bypass the linear layer computation, allowing the Liger Fused Linear Cross Entropy Loss to handle the forward projection weights. "
Is this the same thing as you discuss above?
reply
kouteiheika
1 hour ago
[-]
Yes.

Although this wasn't integrated into PyTorch itself (but into torchtune, which is a different thing). If you're writing your own training loop you need to use a third-party kernel, e.g. the Liger kernel mentioned in the article, or Cut Cross Entropy (which is much better than the Liger one, although IIRC it has a numeric bug in one of its kernels making the results very slightly off).

reply
hirako2000
2 hours ago
[-]
Activations would still require gigabytes for even a few thousand tokens of context.

There are plenty of techniques to optimise, but the question is what an RTX 3080 can train before OOM. The answer is: not that much.

It can barely do quantized fine-tuning, and even then only with a small context.

reply
kouteiheika
1 hour ago
[-]
> Activation would still require gigabytes for a few kb context.

For that you use activation checkpointing, and you can also offload that to the CPU in a smart way to hide the latency. Although, yes, for long context training the activations do dominate the memory usage (and quantizing them degrades things more than just quantizing weights and/or optimizer states).

reply
giancarlostoro
4 hours ago
[-]
> This is pretty awesome. The only compute I have at home is an RTX 3080 with 10 GB of VRAM, so I struggle with training larger models (>40M, 50M params). I get OOM errors and have to optimize a lot.

I'm on the same GPU; it's intimidating enough that I wonder if I should even bother training anything at all. Do you mind sharing what kind of training you've done with that GPU? :)

reply
kouteiheika
4 hours ago
[-]
This isn't really anything new; I've been doing something like this for quite a while, I just haven't bothered writing a paper. (: Probably anyone who would seriously tackle the problem of "how do I train a huge model on a tiny amount of VRAM?" would come up with something similar.

However, most people in the field don't, because the actual practical utility of training huge models on a single GPU is quite low. (e.g. they got 341 tok/s for a 14B model on a single 3090, while with my method I was getting ~1k tok/s on a single 4090; that's still very slow)

Also, there are more tricks one can use to speed up training / lower VRAM usage which they're not using. For example, you don't need any gradient offloading (you can just accumulate the gradients directly into the optimizer's states if you modify your optimizer), you can use Muon instead of Adam (which needs only half the VRAM of Adam), you can use quantization (both for the parameters and for the optimizer states; e.g. I found Muon quantized into 4-bit to work relatively well), etc.
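The "accumulate directly into optimizer state" trick amounts to fusing the optimizer step into the backward pass, per layer, so no separate gradient buffers ever persist. A toy NumPy sketch with momentum SGD (hypothetical layer/gradient names, not the commenter's actual code):

```python
import numpy as np

# Toy setup: three "layers" of parameters with momentum SGD state.
rng = np.random.default_rng(0)
params = [rng.standard_normal(4) for _ in range(3)]
moments = [np.zeros(4) for _ in params]
orig = [p.copy() for p in params]
beta, lr = 0.9, 0.01

def layer_grad(p):
    # stand-in for backprop through one layer (gradient of sum of squares)
    return 2.0 * p

# Instead of materializing gradient buffers for every layer and then calling
# a global optimizer step, fold each layer's gradient into the optimizer
# state the moment backward produces it, update, and let the gradient die.
for i in range(len(params)):
    g = layer_grad(params[i])            # produced during backward
    moments[i] = beta * moments[i] + g   # accumulate straight into momentum
    params[i] -= lr * moments[i]         # update in place; g is now freeable
```

This removes an entire model-sized buffer (the gradients) from peak memory, at the cost of coupling your optimizer to the backward pass.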

reply
stevemk14ebr
2 hours ago
[-]
As the saying goes, POC or GTFO

I invented faster than light travel, it was obvious, just didn't write a paper yet either :)

reply
sabedevops
3 hours ago
[-]
Can you take the time to write up your methods? I'd be interested in reading them.
reply
vlovich123
3 hours ago
[-]
341 is two orders of magnitude faster than your 1 tok/s so it doesn’t seem like their stuff is all that obvious. I also have no baseline for training to know if 341tok/s is slow but it seems speedy for a 3090.
reply
bastawhiz
3 hours ago
[-]
OP said 1k, not 1
reply
SubiculumCode
3 hours ago
[-]
:) Coffee is good
reply
rolandr
3 hours ago
[-]
1k tok/s = 1000 tok/s...
reply
thrawa8387336
2 hours ago
[-]
OOM is log10
reply
bilekas
3 hours ago
[-]
> H200 GPU with 1.5TB host memory,

While yes, it's one GPU... it's not exactly a slim one.

reply
nekusar
3 hours ago
[-]
When the comparison is against 128 H100s, yeah, this is a crazy good upgrade.

And you can rent H100s and H200s for not that much per hour.

reply
p_stuart82
1 hour ago
[-]
$2-4/hr always sounds cheap until you multiply by wall clock and reruns
reply
jeremyjh
39 minutes ago
[-]
Yes, but they are getting only 341 tok/s. A 2.5 trillion token run would take over 200 years.
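The back-of-envelope check (2.5T tokens is an illustrative frontier-scale budget, not a figure from the paper):

```python
tokens = 2.5e12            # illustrative pretraining token budget
tok_per_s = 341            # throughput the paper reports (per this thread)
years = tokens / tok_per_s / (3600 * 24 * 365)
print(round(years))        # ~232 years
```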
reply
anshumankmr
3 hours ago
[-]
>cries in RTX 3060
reply
drob518
2 hours ago
[-]
I’m curious how this technique works, or not, with unified memory architectures such as Apple’s M series. It seems like it’s relying on using overlapping processes to help speed things up, but I would assume that having everything unified in main memory such that you don’t have to transfer everything back and forth to the GPU would also have some advantages. Can someone wiser explain this to me?
reply
WithinReason
5 hours ago
[-]
I was wondering how well this would work :) You can definitely push this further, the question is: how well can the gradients and updates compress?
reply
ilaksh
4 hours ago
[-]
How long would it actually take to train a 120B model on an H200? What if you have 8?
reply
1aurent29
4 hours ago
[-]
Sounds very similar to https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_... I wonder how much of this could be replicated using only this PyTorch primitive.
reply
ur-whale
53 minutes ago
[-]
Why is it no one ever talks about the one thing no one can get their hands on except the big labs ?

I'm talking about the training set.

Sure there are some open sets out there.

But my guess is they are nowhere near what OpenAI, Google and Anthropic are actually using.

Happy to be proven wrong.

reply
olliepro
5 hours ago
[-]
This would likely only get used for small finetuning jobs. It’s too slow for the scale of pretraining.
reply
onion2k
5 hours ago
[-]
> It’s too slow for the scale of pretraining.

There isn't really such a thing as 'too slow' as an objective fact though. It depends on how much patience and money for electricity you have. In AI image gen circles I see people complaining if a model takes more than 5s to generate an image, and other people on very limited hardware who happily wait half an hour per image. It's hard to make a judgement call about what 'too slow' means. It's quite subjective.

reply
jandrese
5 hours ago
[-]
If it would take so long to train that the model will be obsolete before the training is finished, that might be considered too slow. With ML you can definitely hit a point where it is too slow for any practical purpose.
reply
ismailmaj
4 hours ago
[-]
Obsolete because of what? Because with limited hardware you’re never aiming for state of the art, and for fine-tuning, you don’t steer for too long anyway.
reply
jandrese
4 hours ago
[-]
Because there is a new model that is better, faster, more refined, etc...

If your training time is measured in years or decades it probably won't be practical.

reply
jwilber
4 hours ago
[-]
That’s just playing semantics. Nobody is talking about "objective facts" or needs to define them here. If the step time is measured in days and your model takes years to train, then it will never get trained to completion on consumer hardware (the entire point).
reply
greenavocado
5 hours ago
[-]
So distribute copies of the model in RAM to multiple machines, have each machine update different parts of the model weights, and sync updates over the network
reply
olliepro
2 hours ago
[-]
Decentralized training makes a lot more sense when the required hardware isn't a $40K GPU...
reply
atlgator
3 hours ago
[-]
The GPU is no longer the brain, it's the hand. The brain is your RAM. Suddenly that 256GB DDR5 build your wife questioned is 'research infrastructure.'
reply
l1n
5 hours ago
[-]
Seems similar to Microsoft DeepSpeed.
reply
bee_rider
4 hours ago
[-]
They compare against "DeepSpeed ZeRO-3", apparently.
reply
jazzpush2
3 hours ago
[-]
FWIW, ZeRO-3 refers to a common strategy for sharding model components across GPUs (commonly called FSDP, Fully Sharded Data Parallel). The "3" is the level of sharding (how much stuff to distribute across GPUs, e.g. just weights, versus optimizer state as well, etc.)
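The usual per-GPU memory accounting for the three stages, under the common mixed-precision Adam layout (2 bytes fp16 weights + 2 bytes fp16 grads + 12 bytes optimizer state per parameter; model size and GPU count below are illustrative):

```python
def zero_per_gpu_gb(n_params, n_gpus, stage):
    # bytes per parameter: fp16 weights, fp16 grads,
    # and 12 bytes of Adam state (fp32 master copy + m + v)
    p, g, o = 2.0, 2.0, 12.0
    if stage >= 1: o /= n_gpus   # ZeRO-1: shard optimizer states
    if stage >= 2: g /= n_gpus   # ZeRO-2: also shard gradients
    if stage >= 3: p /= n_gpus   # ZeRO-3: also shard the parameters
    return n_params * (p + g + o) / 1e9

# e.g. a 7B-parameter model on 8 GPUs, per-GPU GB at each stage
for stage in range(4):
    print(stage, zero_per_gpu_gb(7e9, 8, stage))
```

Stage 0 (no sharding) needs 16 bytes/param per GPU, which is why even a 7B model won't fit on a 24GB card without some form of ZeRO/FSDP or offloading.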
reply