The problem: Right now, 90% of LLM workloads run on NVIDIA GPUs, but there are equally powerful and more cost-effective alternatives out there. For example, training and serving Llama 3.1 on Google TPUs is about 30% cheaper than on NVIDIA GPUs.
But developer tooling for non-NVIDIA chipsets is lacking. We felt this pain ourselves. We initially tried using PyTorch XLA to train Llama 3.1 on TPUs, but it was rough: the XLA integration with PyTorch is clunky, libraries are missing (bitsandbytes didn't work), and the HuggingFace errors were cryptic.
We then took a different route and translated Llama 3.1 from PyTorch to JAX. Now it's running smoothly on TPUs! We still have challenges ahead (there's no good LoRA library in JAX, for one), but this feels like the right path forward.
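(To give a flavor of what the port looks like, here's roughly the JAX side of a Llama-style RMSNorm; this is illustrative, not our exact code. The main shift from PyTorch is that JAX is functional, so parameters are passed in explicitly instead of living on a module.)

    import jax
    import jax.numpy as jnp

    def rms_norm(x, weight, eps=1e-5):
        # Llama-style RMSNorm: normalize by root-mean-square, then scale.
        # In PyTorch `weight` would be an nn.Parameter on a module; in JAX
        # it's just an array you thread through the forward pass.
        variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
        return weight * x * jax.lax.rsqrt(variance + eps)

    # jit once and XLA compiles/fuses it for the TPU
    rms_norm = jax.jit(rms_norm, static_argnames="eps")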
Here's a demo (https://dub.sh/felafax-demo) of our managed solution.
Would love your thoughts on our repo and vision as we keep chugging along!
FWIW, if this helps prioritize: personally I'd find LoRA training for Llama 3.1 most useful (which it sounds like currently isn't well-supported with Felafax?). With something like vLLM you can serve large numbers of LoRAs that share the same underlying GPU resources (assuming they're based on the same base model), whereas with full finetunes each model needs to deploy on its own set of GPUs. In general I'd guess full finetunes will be less cost-effective for most enterprise use cases: finetuning (whether full finetuning or PEFT) generally improves only task-specific performance. So, assuming you've got more than one task you want to use a model for in your business, it pretty quickly becomes dramatically cheaper to do those tasks with LoRAs rather than full finetunes, unless you're saturating the boxes for each specific task. So, I'm hoping you guys build support for LoRA training with JAX in addition to full finetuning!
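(To illustrate the serving pattern described above: with vLLM's LoRA support, one copy of the base model on the GPU serves many adapters, selected per request. The adapter names and paths below are made up for illustration.)

    from vllm import LLM, SamplingParams
    from vllm.lora.request import LoRARequest

    # One set of base-model weights on the GPU, shared by every adapter
    llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct",
              enable_lora=True, max_loras=8)
    params = SamplingParams(temperature=0.0, max_tokens=128)

    # Each task gets its own tiny LoRA adapter instead of its own GPUs
    summarizer = LoRARequest("summarize", 1, "/adapters/summarize")
    classifier = LoRARequest("classify", 2, "/adapters/classify")

    print(llm.generate(["Summarize: ..."], params, lora_request=summarizer))
    print(llm.generate(["Classify: ..."], params, lora_request=classifier))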
btw, we do have LoRA support in our Llama 3 PyTorch XLA model. Check that out in the meantime!
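(For anyone wondering what LoRA in JAX would even take: the core math is tiny. A minimal hand-rolled sketch, not Felafax code; `rank` and `alpha` are the usual LoRA hyperparameters.)

    import jax
    import jax.numpy as jnp

    def init_lora(key, d_in, d_out, rank=8):
        # A starts small and random, B starts at zero, so the adapter
        # initially contributes nothing (standard LoRA init)
        k_a, _ = jax.random.split(key)
        A = jax.random.normal(k_a, (d_in, rank)) * 0.01
        B = jnp.zeros((rank, d_out))
        return A, B

    def lora_linear(x, W, A, B, alpha=16.0):
        # Frozen base weight W plus the trainable low-rank update A @ B
        rank = A.shape[1]
        return x @ W + (alpha / rank) * (x @ A) @ B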
Let's assume you're doing a single-epoch LoRA training run. A single H100 should be enough to train Llama 3.1 8B, and it should crank through 264MM tokens in a couple hours, IMO. Since you're not doing multi-GPU training, a PCIe H100 should be fine — you don't need the slightly pricier SXM H100s — and the PCIe versions go for about $2.50/hr on Runpod.
So: about $5 for a custom model that's probably the best in the world at whatever your task is! (Even if it might be a little dumber at other tasks.) Insanely cheap when you think about it.
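(A back-of-envelope check of those numbers; the throughput figure is an assumption on my part, not a benchmark.)

    tokens         = 264e6    # single-epoch dataset from the parent comment
    tokens_per_sec = 35_000   # assumed H100 LoRA throughput on an 8B model
    usd_per_hour   = 2.50     # Runpod H100 PCIe, on-demand

    hours = tokens / tokens_per_sec / 3600
    print(f"{hours:.1f} h -> ${hours * usd_per_hour:.2f}")  # ~2.1 h -> ~$5.24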
TPUs won't beat H100s on price for on-demand personal use cases, but for reserved capacity (i.e. businesses) they're slightly cheaper.
So it would seem the real cost becomes converting/curating the data into a usable format first.
The conclusion is at the bottom, but the TL;DR was that TPUs were 33% cheaper on a performance-per-dollar basis and JAX scales very well compared to PyTorch.
If you are curious, there was a thorough comparison done by Cohere, published here: https://arxiv.org/pdf/2309.07181 -- TPU+JAX turned out to be more performant and more fault-tolerant (fewer weird errors).
When you say this, you should specify which NVIDIA GPU you mean (I assume H100 SXM) and the price you are assuming for that GPU.
One can't simply compare based on the on-demand price on GCP, because the NVIDIA GPUs there are extremely overpriced.
Runpod is ever-so-slightly cheaper than Google TPUs on-demand on a per-GB basis: about 4.3 cents an hour per GB for Runpod vs 4.4 cents an hour per GB for a TPU. But let's look at how they compare with reserved pricing. Runpod is $2.79/hr with a 3-month commitment (the longest commitment period they offer), whereas Google offers v5p TPUs for $2.94/hr for a 1-year commitment (the shortest period they offer; and to be honest, you probably don't want to make 3-year commitments in this space, since there are large perf gains in successive generations).
If you're willing to do reserved capacity, Google is cheaper than Runpod per GB of RAM you need for training or inference: Runpod is about 3.4 cents per GB per hour vs Google at about 3.09 cents per GB per hour. Additionally, Google presumably has a lot more TPU capacity than Runpod has GPU capacity, and doing multi-node training is a pain with GPUs and less so with TPUs.
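(On the multi-node point: in JAX the same sharding code runs unchanged from one chip to a whole pod slice, which is a big part of why TPUs are less painful here. A minimal data-parallel sketch; shapes are illustrative.)

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    # On a multi-host TPU slice, jax.devices() already spans every host;
    # there's no per-node launcher or NCCL setup to wire up.
    mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
    batch = jax.device_put(jnp.zeros((1024, 4096)),
                           NamedSharding(mesh, P("data")))  # split along axis 0
    print(batch.sharding)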
Another cheap option to benchmark against is Lambda Labs. Now, Lambda is pretty slow to boot and considerably more annoying to work with (e.g. they only offer preconfigured VMs, so you'll need to do some kind of management on top of them). They offer H100s for $2.99/hr "on-demand" (although in my experience, prepare to wait 20+ minutes for the machines to boot); if cold-boot times don't matter to you, they're even better than Runpod if you need large machines (they only offer 8xH100 nodes, though: nothing smaller). For a 1-year commit, they'll drop prices to $2.49/hr... which is still more expensive on a per-GB basis than TPUs (3.11 cents per GB per hour vs 3.09 cents per GB per hour), and again, I'd trust Google's TPU capacity more than Lambda's H100 capacity.
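(Reproducing the per-GB math above; I'm taking the per-GB basis to mean accelerator HBM, i.e. 80 GB per H100 and 95 GB per v5p chip.)

    offers = {
        "Runpod H100, 3-mo reserved":  (2.79, 80),
        "Google v5p, 1-yr reserved":   (2.94, 95),
        "Lambda H100, on-demand":      (2.99, 80),
        "Lambda H100, 1-yr reserved":  (2.49, 80),
    }
    for name, (usd_per_hr, gb) in offers.items():
        print(f"{name}: {100 * usd_per_hr / gb:.2f} cents/GB/hr")
    # Runpod 3.49, v5p 3.09, Lambda 3.74 and 3.11 -- matching the figures above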
It's not dramatically cheaper than the cheapest GPU options available, but it is cheaper if you're working with reserved capacity — and probably more reliably available in large quantities.
What matters is steps per $ and, to some degree, speed (I'm happy to pay a premium sometimes to get the fine-tuning results faster).
At what scale were you able to get a significant discount and how much?
Most people will be (full) fine-tuning on 8xH100 or 16xH100 for a few days at a time.
We spent nearly 4 weeks getting PyTorch XLA working on TPU. Hope that answers your question!
That said, in terms of single-GPU speed, we believe we would be behind but not too far off, thanks to JAX+TPU's more performant stack. Additionally, we can do larger-scale multi-node training on TPUs.
There are still more optimizations we need to do for Llama 3.1, such as adding memory-efficient attention kernels in Pallas (a toy sketch of the Pallas API is below).
[0] https://github.com/felafax/felafax/blob/main/llama3_pytorch_...
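(For context, Pallas is JAX's extension for writing custom TPU/GPU kernels. A toy example of the API's shape, nothing like a real fused-attention kernel:)

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl

    def scale_kernel(x_ref, o_ref):
        # Kernels read and write through refs that live in on-chip memory
        o_ref[...] = x_ref[...] * 2.0

    @jax.jit
    def scale(x):
        return pl.pallas_call(
            scale_kernel,
            out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        )(x)

    print(scale(jnp.ones((8, 128))))  # all 2.0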
Also, calculating GPU costs is getting quite nuanced, with a wide range of prices (https://cloud-gpus.com/) and other variables that make it harder to do an apples-to-apples comparison.
I’m curious about your reported 30-70% speedup.
No, we haven't run our JAX + XLA code on NVIDIA chipsets yet. I'm not sure if NVIDIA has good XLA backend support.
At the bottom, it shows the calculations around the 30% cost efficiency of TPU vs GPU.
Our range of 30-70% is based on some numbers we collected from running fine-tuning runs on TPU and comparing them to similar runs on NVIDIA (though not using our code but other OSS libraries).
They have some good uses, but LLMs ain't it.
They have almost nothing in common with Cloud TPUs.
Google also has Cloud TPUs, which are their server-side accelerators, and this is what we are initially trying to build for!