Scale AI with Google's TPU software stack

Google • Google for Developers • May 21, 2026 at 11:51 PM • 37:52

TL;DR

Google is expanding its AI infrastructure with specialized TPUs and a full software stack aimed at making large model training, fine-tuning, and inference more efficient and accessible.

KEY POINTS

Shift Toward Specialized AI Hardware

Google is moving away from general-purpose systems toward specialized chips tailored for different AI workloads. The TPU v8 family separates training and inference tasks: TPU 8t focuses on high-throughput training at scale, while TPU 8i is optimized for low-latency, cost-efficient inference. This reflects a broader industry trend where inference—especially for “thinking” models that consume large token volumes—has become a dominant source of computational demand.

Inference Becomes Central to Model Intelligence

Advances in reasoning models are shifting more computational complexity into inference rather than training. These systems generate longer chains of reasoning, increasing the importance of efficient runtime execution. As a result, optimizing memory usage, scheduling, and hardware utilization is now as critical as model design itself.

vLLM Improves Serving Efficiency

The vLLM inference engine addresses key bottlenecks such as memory fragmentation and irregular request patterns. Its PagedAttention system virtualizes KV cache into fixed-size blocks, improving utilization, while continuous batching dynamically schedules token-level workloads. Additional features like prefix caching enhance performance in conversational and agent-based applications, enabling reuse of shared computation across requests.
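
The paging idea can be illustrated with a toy block allocator. This is a minimal sketch, not vLLM's implementation: the class name, the block size, and the method names are all hypothetical, and the cache tracks only block ids rather than real key/value tensors. It shows why fixed-size blocks help: a sequence claims a new block only when its current one is full, and a finished sequence returns its blocks to a shared pool.

```python
from dataclasses import dataclass, field

BLOCK_SIZE = 4  # tokens per KV block (hypothetical value chosen for readability)


@dataclass
class PagedKVCache:
    """Toy allocator mapping each sequence to a list of physical block ids."""
    num_blocks: int
    free_blocks: list = field(default_factory=list)
    block_tables: dict = field(default_factory=dict)  # seq_id -> [block ids]
    seq_lens: dict = field(default_factory=dict)      # seq_id -> token count

    def __post_init__(self):
        # Every physical block starts out free.
        self.free_blocks = list(range(self.num_blocks))

    def append_token(self, seq_id: str) -> None:
        """Reserve room for one more token, allocating a block only when needed."""
        length = self.seq_lens.get(seq_id, 0)
        if length % BLOCK_SIZE == 0:  # no block yet, or the last block is full
            if not self.free_blocks:
                raise MemoryError("no free KV blocks")
            self.block_tables.setdefault(seq_id, []).append(self.free_blocks.pop(0))
        self.seq_lens[seq_id] = length + 1

    def free(self, seq_id: str) -> None:
        """Return all blocks of a finished sequence to the free pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.seq_lens.pop(seq_id, None)


cache = PagedKVCache(num_blocks=8)
for _ in range(5):
    cache.append_token("a")  # 5 tokens need 2 blocks
for _ in range(3):
    cache.append_token("b")  # 3 tokens need 1 block
```

Because blocks need not be contiguous, at most one partially filled block per sequence is wasted, regardless of how long other sequences grow.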

Cross-Platform Portability for Developers

vLLM provides a unified backend supporting both JAX and PyTorch, and can run on TPUs and GPUs without requiring changes to application code. This portability allows developers to switch hardware environments without rewriting serving layers, lowering operational complexity.
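
One way to see the portability claim from the client side: vLLM exposes an OpenAI-compatible HTTP server, so a request carries no hardware-specific fields. The sketch below only builds the request and does not send it. The host, port, and model name are placeholders assumed for illustration, not values from the talk.

```python
import json
import urllib.request


def build_chat_request(base_url: str, model: str, prompt: str) -> urllib.request.Request:
    """Build a chat-completion request; nothing in it names TPUs or GPUs."""
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 64,
    }
    return urllib.request.Request(
        url=f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


# Placeholder endpoint and model name (assumptions, not from the source).
req = build_chat_request(
    "http://localhost:8000",
    "my-gemma-model",
    "Write a haiku about Google Cloud TPUs.",
)
# urllib.request.urlopen(req) would send it to a running server.
```

Pointing the same client at a server backed by different hardware would only require changing `base_url`.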

Demonstration of Large-Scale Model Serving

A 31 billion-parameter Gemma 4 model was served with vLLM on a virtual machine with 8 Trillium TPU chips, using a tensor parallel size of 8. The model could fit on four chips; the extra chips leave headroom for the KV cache, and tpu-info showed memory nearly maxed out across all eight. The demo also showed asynchronous request handling and profiling with XProf for performance tuning, highlighting how large models can be efficiently served in production-like environments.
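
A rough back-of-the-envelope calculation shows why extra chips matter for serving. It rests on two assumptions that the talk does not state: weights stored in 16-bit precision (2 bytes per parameter) and 32 GB of HBM per Trillium chip. Activations and runtime overhead are ignored.

```python
HBM_PER_CHIP_GB = 32.0  # assumed per-chip HBM for Trillium (not stated in the talk)


def weight_memory_gb(num_params: float, bytes_per_param: int = 2) -> float:
    """Memory for model weights alone, assuming 2 bytes per parameter by default."""
    return num_params * bytes_per_param / 1e9


def kv_headroom_gb(num_chips: int, weights_gb: float) -> float:
    """HBM left over for KV cache after weights are sharded across the chips."""
    return num_chips * HBM_PER_CHIP_GB - weights_gb


weights = weight_memory_gb(31e9)
headroom_4 = kv_headroom_gb(4, weights)
headroom_8 = kv_headroom_gb(8, weights)
print(f"weights: {weights} GB, headroom on 4 chips: {headroom_4} GB, on 8 chips: {headroom_8} GB")
```

Under these assumptions, doubling the chip count roughly triples the memory available for KV cache, which is what allows more concurrent or longer requests.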

Speculative Decoding Boosts Speed

A new implementation of diffusion-style speculative decoding significantly accelerates inference. By allowing a smaller model to predict multiple tokens in parallel and having a larger model verify them, the system achieves up to a 3× speed improvement compared to traditional autoregressive decoding.
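
The draft-and-verify loop can be sketched with toy stand-ins for both models. Everything here is hypothetical: the "target" and "draft" are deterministic functions over integers, the draft proposes a whole block of guesses at once (loosely mirroring the all-at-once drafting described above, though it is not a diffusion model), and acceptance is a simple greedy match rather than the probabilistic rule real systems may use.

```python
from typing import Callable, List


def speculative_step(
    context: List[int],
    target_next: Callable[[List[int]], int],
    draft_block: Callable[[List[int], int], List[int]],
    k: int,
) -> List[int]:
    """One draft-and-verify round; returns the newly accepted tokens."""
    guesses = draft_block(context, k)
    accepted: List[int] = []
    for guess in guesses:
        # A real engine checks all guesses in one batched pass of the large model.
        expected = target_next(context + accepted)
        if guess != expected:
            accepted.append(expected)  # keep the target's token and stop
            return accepted
        accepted.append(guess)
    # All guesses matched, so the target contributes one extra token.
    accepted.append(target_next(context + accepted))
    return accepted


def target_next(ctx: List[int]) -> int:
    """Toy 'large model': counts upward modulo 10."""
    return (ctx[-1] + 1) % 10


def draft_block(ctx: List[int], k: int) -> List[int]:
    """Toy 'small model': usually agrees with the target but never predicts 5."""
    out, last = [], ctx[-1]
    for _ in range(k):
        nxt = (last + 1) % 10
        if nxt == 5:
            nxt = 0  # deliberate mistake
        out.append(nxt)
        last = nxt
    return out


round_one = speculative_step([1], target_next, draft_block, k=4)
round_two = speculative_step([5], target_next, draft_block, k=3)
```

With greedy acceptance the output is identical to what the target alone would have produced; the gain comes from accepting several tokens per round of target verification.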

Tunix Simplifies Post-Training and Reinforcement Learning

The Tunix framework enables lightweight fine-tuning and reinforcement learning on TPUs, even using free resources like Kaggle or Colab. It supports techniques such as supervised fine-tuning (SFT), reinforcement learning, and knowledge distillation, allowing smaller models to inherit capabilities from larger ones using richer training signals.
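
The "richer training signal" in distillation can be made concrete with a small numeric sketch. This is not Tunix code: the function names are hypothetical, and the loss shown is a plain KL divergence between teacher and student token distributions, which is one common formulation. Two students that both rank the correct token first are indistinguishable under a hard-label loss but get different distillation losses.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def hard_label_loss(student_logits: List[float], target_index: int) -> float:
    """Cross-entropy against a single correct token (the next-token signal)."""
    return -math.log(softmax(student_logits)[target_index])


def distillation_loss(
    student_logits: List[float],
    teacher_logits: List[float],
    temperature: float = 1.0,
) -> float:
    """KL(teacher || student) over the full token distribution."""
    p = softmax(teacher_logits, temperature)
    q = softmax(student_logits, temperature)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


teacher = [2.0, 1.0, 0.1]
student_a = [2.0, 1.0, 0.1]  # matches the teacher's full ranking
student_b = [2.0, 0.1, 1.0]  # same top token, but the runners-up are swapped

hard_a = hard_label_loss(student_a, target_index=0)
hard_b = hard_label_loss(student_b, target_index=0)
soft_a = distillation_loss(student_a, teacher)
soft_b = distillation_loss(student_b, teacher)
```

Here `hard_a` and `hard_b` are equal, while `soft_b` is larger than `soft_a`, so only the distillation loss tells the student that its scores for the other tokens are off.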

Efficient Fine-Tuning with Smaller Models

A 4 billion-parameter model was fine-tuned with GRPO (Group Relative Policy Optimization) to perform the same multimodal food-logging task demonstrated with the 31 billion-parameter model. After training, it ran on a single Trillium chip instead of the four to eight chips used for the larger model. The full training run takes roughly 30 to 45 minutes and was not executed live in the demo.
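
The core of GRPO, the technique named in the demo, is that each sampled response is scored relative to the other responses for the same prompt, so no labeled dataset is needed. The sketch below is not Tunix code. The reward fields and weights are hypothetical stand-ins for the tool-call, database-response, and formatting checks described in the talk, and normalizing by the population standard deviation is one common choice rather than a confirmed detail of the demo.

```python
import statistics
from typing import Dict, List


def toy_reward(rollout: Dict[str, bool]) -> float:
    """Hypothetical composite reward; field names and weights are illustrative."""
    score = 0.0
    if rollout.get("called_tool"):
        score += 1.0
    if rollout.get("tool_result_used"):
        score += 1.0
    if rollout.get("well_formatted"):
        score += 0.5
    return score


def group_relative_advantages(rewards: List[float], eps: float = 1e-6) -> List[float]:
    """Center and scale rewards within one group of rollouts for the same prompt."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


# Four sampled responses to the same prompt.
group = [
    {"called_tool": True, "tool_result_used": True, "well_formatted": True},
    {"called_tool": True, "tool_result_used": False, "well_formatted": True},
    {"called_tool": False, "tool_result_used": False, "well_formatted": False},
    {"called_tool": True, "tool_result_used": True, "well_formatted": False},
]
rewards = [toy_reward(r) for r in group]
advantages = group_relative_advantages(rewards)
```

Responses that beat the group average get positive advantages and are reinforced; those below it are discouraged. If every response earns the same reward, all advantages are zero and the group contributes no learning signal, which is one reason reward design matters.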

MaxText Enables Scalable Pre-Training

MaxText provides open-source, production-ready configurations for training large models across thousands of chips. Built on JAX and OpenXLA, it includes optimized “recipes” for models like Gemma, Mistral, and DeepSeek, allowing developers to scale from single-machine experiments to distributed training without redesigning workflows.

JAX and OpenXLA Power the Stack

The ecosystem relies heavily on JAX, which offers composable transformations for differentiation, compilation, and parallelism. Beneath it, OpenXLA manages low-level optimization and distributed execution, enabling efficient coordination across massive TPU clusters.
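
A minimal sketch of how these transformations compose, using a toy linear model that is an illustration only and not taken from the talk: `jax.grad` differentiates a plain Python function, `jax.vmap` maps the result over a batch, and `jax.jit` compiles the whole thing with XLA.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    """Squared error of a linear model on a single example."""
    return (jnp.dot(w, x) - y) ** 2


# Gradient with respect to the first argument (w) for one example.
grad_loss = jax.grad(loss)

# Vectorize over a batch of (x, y) pairs while sharing the same w.
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# Compile the batched gradient function with XLA.
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0]])
ys = jnp.array([0.0, 0.0])

grads = fast_batched_grad(w, xs, ys)  # one gradient row per example
```

The original `loss` function is untouched; each transformation wraps it and returns a new function, which is what makes them composable.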

TorchTPU Expands PyTorch Compatibility

Google is developing TorchTPU, a native PyTorch stack for TPUs built in collaboration with Meta. The goal is to allow developers to run PyTorch models on TPUs with minimal or no code changes, improving accessibility for teams already invested in CUDA-based workflows.

Kinetic Reduces Infrastructure Overhead

The Kinetic project simplifies deployment by automating TPU cluster setup through minimal configuration, allowing developers to focus on code rather than infrastructure management.

CONCLUSION

Google’s strategy combines specialized hardware with an increasingly mature software ecosystem, aiming to make large-scale AI development faster, more efficient, and more accessible across both training and inference workloads.

Full transcript

[MUSIC PLAYING] GIRIJA SATHYAMURTHY: So over the past several years, Google has been driving an incredible pace of innovation in AI. And when you look across this timeline, you will see an evolution of models and modalities all the way from Gemini to Gemma, Veo, Imagen, and Omni that was announced just yesterday. So what's incredible about this timeline is that it's not just the cadence of these releases but the breakthrough capabilities that are being unlocked with every launch, things like multi-modal understanding, reasoning, and agentic workflows. So today, we want to go one layer deeper and talk to you about what powers these models, and not just from a silicon standpoint, but from a full software stack standpoint, and talk to you about some of the tools that you can use as you build, train, fine tune, serve state-of-the-art models on Google Cloud TPUs. JOSH GORDON: OK, awesome. So as you know, there's roughly four stages to building a large language model. First, you would design your neural network. And to do that, you'd use a framework like JAX. Then you have pre-training, post-training, and inference. In pre-training, you're teaching the model basically to predict the next token. In post-training, you're teaching the model to be useful. You could do something like supervised fine tuning to teach it to follow instructions. And then you could do something like reinforcement learning to teach it to reason. And then after that, you need to serve the model to users, or do inference. And what's really cool is if you look at this chart, you can look at roughly where the intelligence is coming from. And what's really cool recently is you'll see that a lot of the intelligence is actually coming from inference. And if you think about why that is, it's because now we have thinking models that are consuming lots of tokens as they reason over your problems. And so what's cool is we have these two distinct phases of compute. 
We have training compute, and we have inference compute. And the needs are a little bit different. And we've designed our hardware stack, our TPU stack with that in mind. GIRIJA SATHYAMURTHY: Yeah. So as Josh mentioned, as you can see here, there is a clear separation from an architectural standpoint between systems that specialize in training workloads and those that specialize in inference workloads. So TPU v8 is our latest generation of TPUs that we announced at Cloud Next. And TPU 8t is highly optimized for large scale training workloads. So the focuses here are things like maximizing your throughput and scaling efficiency. And TPU 8i is great for inference workloads. So the focuses are things like low latency and cost efficiency. So we are moving away from one general purpose system into hardware that is specialized for the needs of different types of workloads. But hardware is just one part of the puzzle. So let's talk about some of the software tools that we mentioned earlier. JOSH GORDON: OK, awesome. So in this talk, we'll take you through a bunch of software that you can try at home. We're going to start with inference because we just talked about how important that is for thinking models. A really cool thing to say about inference-- pre-training a model from scratch takes a lot of compute. But if you're working on inference and also post-training, that's absolutely something you can do with the free TPUs that are available in Kaggle or on Colab. So we'll start with a demo of serving Gemma 4 using vLLM on TPUs. And then we'll move into another demo of doing some post-training with a different version of Gemma using Tunix, which is our lightweight post-training framework. Then quickly, we'll go through MaxText, which has high performance reference implementations of all sorts of large language models implemented in JAX. And these will scale out of the box to thousands of chips, which are super cool. 
And then we'll talk about frameworks for actually building models. So JAX, which is used to build Gemini and Gemma. And then we'll talk about TorchTPU, which is really cool. It's a new software we're working on to make it really easy for you to run your PyTorch code on TPUs without code changes. GIRIJA SATHYAMURTHY: Awesome. JOSH GORDON: And to get started, Rob will give a demo on vLLM. Sorry. GIRIJA SATHYAMURTHY: I will do a quick intro, and then you have to hold for 5 minutes for the demo. JOSH GORDON: Sorry, Rob, I jumped the gun. GIRIJA SATHYAMURTHY: No. Let's start with inference. So LLM serving is no longer just about compute. In fact, at scale, some of the harder problems are usually memory management, scheduling, and hardware utilization, and also being able to abstract away all of these complexities. This is where vLLM comes in. Let's talk about some of these challenges, and then we'll talk about how vLLM helps you address those. So as context lengths grow and concurrency increases, KV cache management becomes one of the biggest bottlenecks in inference. So vLLM uses PagedAttention. It virtualizes KV cache into fixed-size blocks, so this tremendously helps improve utilization. But this is just the foundation, and there are several optimizations on top as well. For example, prefix caching is ideal for agentic workloads and conversational workloads. The idea is for several requests to have one common prefix. So instead of the system having to repeatedly calculate KV cache, it can simply use whatever was previously calculated. So let's talk a little bit about scheduling now. Inference traffic is highly irregular. You are constantly dealing with different prompt lengths, generation times, arrival patterns, and so on. vLLM uses continuous batching. So instead of using static, request level batches, your system dynamically schedules token level work at runtime. 
This obviously helps with scheduling, but also helps a lot with hardware utilization, which brings us to the third bucket that we were talking about. So especially on TPUs, these workloads benefit a lot from memory layouts and hardware-aware execution strategies. And these work naturally with vLLM's execution model and help you maximize throughput. So the beautiful thing about vLLM is also that all of this is exposed through a unified inference backend. It supports both PyTorch and JAX frameworks, and also helps you run your workloads on both GPUs and TPUs. It also supports a wide range of model architectures. So the goal here is portability and helping developers choose the framework that is right for them, and also being able to move across hardware backends very easily. So the key thing here is that, as a developer, you are never having to rewrite your application layer or serving stack. So with this intro, now let's see vLLM in action. Over to you, Rob. ROB MULLA: Thanks, Girija. That was a great intro about running inference on TPUs. So in this demo, I'm just going to show you how easy it is to get up and running, running the latest Gemma 4 31 billion-parameter model on TPUs. So I am running this notebook on a virtual machine that has 8 Trillium chips. So we can use tpu-info, which is a package for letting us see what the chips are on the machine and the utilization. So we can install vLLM TPU just with pip install vllm-tpu. However, I am running this model, and we recommend running these models using our pre-built Docker containers. This makes it super easy, so you don't have to install all the different packages and dependencies. And it will just run out of the box. We do also have a TPU recipes repo for various models that are supported on different chips. This is not the full extent of models supported for vLLM on TPU, but it's a good place to start if you're looking to run some models on TPUs. 
So in this Docker Compose file that I've written, there's a few things I want to note. First is that I've enabled profiling. I'm going to show you this later, but this allows you to see the low level execution that occurs on the TPU. It's not great for running in production, but when you're debugging, it's really helpful. I also set tensor parallel size to 8, and that's really all you need to do to make sure that it's run across all the TPU chips. And I'm also using the Gemma 4 tool call parser, and I'll show you later how that works. So just to show you here that we do have our Docker container up and running with the latest nightly build of vLLM TPU. And checking out some of the logs here, we can see that it's successfully booted up. And I started this up earlier just to save us from waiting for the compile time. But we're running the 31 billion-parameter Gemma 4 model. And these commands let us know that the process is up and running. So I'm also going to, just during this demo, show these logs here at the bottom here in our console so we can see what's going on. There we go. So now as we call the model server, we can see some of the stats running. Just to show you here again using tpu-info, now we can see the memory utilization across all of our chips is nearly maxed out. Now, this model could actually fit on four chips, but we're using eight because a lot of times you want to have a lot of overhead for the KV cache that your model might need to consume. And then from here on, the great part about the fact that we're using vLLM is that the frontend and your software design is going to be pretty much the same as if you were running vLLM on a GPU. So you can have the same frontend software written, and you could change the backend hardware that you're running on. And it should operate pretty much the same. So let's just send a single query to ask it to write a haiku about Google Cloud TPUs. But we want to do more than just sending a single request. 
And Girija mentioned this before, but one of the really great things about what we've built into vLLM on TPU is this thing called Ragged PagedAttention. So this really does a lot of things. It eliminates the KV cache memory fragmentation. It enables continuous batching, so it's really good with asynchronous requests. And then it has a bunch of custom built kernels that are specific for running on TPUs. And one of those is the fact that it automatically will change between running from prefill-heavy to decode-heavy types of workloads. So a prefill-heavy workload might be a really long input, like images or long context and then requesting a smaller output. And decode would be a smaller input requesting a really long output. So we're just going to send two simultaneous requests, 16 asynchronous requests, to vLLM running on TPU. And it finished relatively quickly. But since we have profiling running, we can now go over to our profiler and actually see all the low level execution here. This is really helpful if you're tuning your model or if you want to really get in. If you're a power user and you're going to be writing your own kernels, XProf allows you to really explore all the execution on the TPU. All right, so just to put this all into context of an actual application that you might want to build, I'm going to show how we could build a food logging assistant. So the idea here is we have an image and some text about some food that we might take a picture of. We want the agent here, Gemma 4, to identify the contents, and then query a database to get all the nutrient information, and finally generate a summary. So Gemma 4 is great for this because it's multi-modal. So we can send it an encoded image of some food, like this, and ask it to describe the contents of the food. And Gemma 4, is really good at doing this sort of summarization. So here we have a complete breakdown of the food in that bowl. 
And then Gemma 4 has also been trained to work really well with tool calling. So tool calling, here, I'm just going to simulate a fake nutrition database that we want to give the model access to. And all we have to do is define this lookup nutrition tool and send it in our request to the endpoint. And then the model knows that it has access to this nutrient database that it can call. So if we call it and ask it for the calories in a bunch of different foods, we can see that the model has then called the tool a number of times. So just to wrap it up and to show you in full view, I vibe coded up with Gemini this front end for what it might look like. And you can just run a query here. See how it's iteratively, as an agent, using those tool calls to call the database, identifying the different things in the image. And then we can look this over. We could tweet it and log it. And later, I'm going to show you how we could fine tune a model to do something very similar. So now I'm going to turn it over to Josh, who is going to talk about some of what's new in vLLM on TPU. JOSH GORDON: Thanks, Rob. That was super ace. ROB MULLA: Thanks. [APPLAUSE] JOSH GORDON: OK, so what's new in vLLM? Can we switch back to the slides, please? Thanks. OK, so lots of reproducible recipes on GitHub. Rob showed you, super cool. Another really cool thing I wanted to mention, and this is one reason I've loved working in open source all these years, we got a really amazing PR from UCSD. And this is for diffusion-style speculative decoding. If you're new to speculative decoding, the idea is kind of crazy. I remember when I learned about it. I was like, wait, that's kind of wild. The idea is if you have a very large model, it's computationally intensive to predict the next token. In speculative decoding, the idea is that you have a smaller model, and the smaller model races ahead and it predicts tokens, basically guessing what the larger model is going to say. 
And for various reasons, it's actually faster to have the large model verify the output from the smaller model than it is to actually predict the tokens. So it's a crazy idea. Now, usually when you do speculative decoding, you do it autoregressively, which is just a fancy way of saying one token at a time. And if you're doing it that way, there's still a slowdown because it's one token, the next token, the next token, the next token. What this pull request did, it's diffusion style speculative decoding. So instead of predicting the next tokens one at a time, you predict them all at once using a diffusion model, which eliminates that bottleneck. And then we incorporated this into vLLM TPU inference, and we got about a 3x speedup on a bunch of tasks. So it's super cool. Hugely grateful to, I believe, Professor Zhang at UCSD. So super awesome. Anyway, open source is great. So now let's take a quick look at Tunix. And Tunix is our post-training and reinforcement learning framework. It's lightweight. You can also absolutely use Tunix at home with the free TPUs on Colab and Kaggle. And we just wrapped up a really cool Kaggle hackathon with Tunix. We got thousands of participants. It was really cool. And the idea was to take an older model of Gemma that didn't know how to reason and then teach it to solve math problems and show the reasoning traces. So it was super cool. One really nice thing about Kaggle, too, is there's a whole bunch of example notebooks you can check out. So all the folks that submitted their projects, that basically serves as this huge knowledge base. So it's a really good place to learn. So Tunix is state-of-the-art. It supports a whole bunch of stuff. It's really awesome. Here's just three quick things it supports. So SFT-- and that's something you could do to, say, teach a model to follow instructions. That typically would be the first stage in the reinforcement learning pipeline. 
After that, it supports a whole bunch of reinforcement learning stuff. Rob will show a demo of GRPO in a sec, which is something that we used in the Kaggle hackathon. It also supports knowledge distillation, which is another one of these crazy ideas. So oftentimes, if you have a big model, a question you could ask is, can I train a smaller model that has similar performance to the large model? And one way you can do that, often when you're training the large model, you might be training it to predict the next token. So the signal you have to learn from is, is the next token right or is it wrong? But after you have a large model trained, you can actually use the large model to train a smaller model, and you have a richer signal. Because when the larger model predicts a token, you want the smaller model to predict the same token. But in addition to knowing the token, you can actually see the probability distribution from the larger model. So you get not just the token that it predicted, but you can see the scores for all the other tokens that it could have predicted. And that's a much richer signal that you can use to train a smaller model, faster, to have good performance. So there's all sorts of these really cool, magical ideas. And Tunix supports that, too. It also supports some really new, cool, and advanced stuff. So if you're interested in agentic workflows-- so the idea would be, let's say that you wanted to train a software engineer model, like a software engineer agent, to actually solve real problems on GitHub. Now you have these very long-running agentic workflows. And Tunix supports post-training for those as well. We've got a bunch of cool examples on GitHub that you can check out. It's super cool. And now Rob's going to show a demo of GRPO with Tunix. ROB MULLA: Thanks, Josh. JOSH GORDON: Thanks. ROB MULLA: OK, I'm back. So previously, I showed how we could, out of the box, use Gemma 31-billion parameter model to run as this food tracking assistant. 
But what if we wanted to ensure that it works in all the different scenarios that we want it to? And what if we wanted to use a smaller model that might be more efficient, that we wouldn't have to run on as much hardware? Well, in this demo, I'm going to show you Tunix, which is our library for fine tuning and doing reinforcement learning on models on TPUs. And it uses JAX underneath the hood. And we're going to fine tune a 4 billion-parameter model to do the same thing as what we showed in the previous one. I'm going to quickly walk through the steps. So not to get into too much of the details here, but we're going to define the two different models that we need for this, the actor and reference model. And we're going to set up a data set using Grain, which is Google's data pipeline library. Then we're going to set up some scoring rewards. And finally, we're going to train the model and then put it in production. All right, so again, I'm going to show you here on tpu-info that we're running on a machine with eight chips. And then we're going to define our actor and reference models. So with GRPO, which is Group Relative Policy Optimization, which is the technique we're using, we actually want two copies of our model. We'll have our actor model. That's the model that we'll be updating the weights of to make better at our task. And we have our reference model, which we use to make sure that the actor model isn't deviating too far from the original model. And creating these and distributing these across our TPUs is super easy with Tunix and JAX. So I'm just going to do that here. And then we can see, after this is running, that the memory use does have them sharded across all eight of our chips. We do leave a lot of overhead here, memory, because doing these reinforcement learning type of training jobs does require a lot of memory usage on the TPUs for doing these training rollouts. Next, I'm going to create the data set that we'll use to train this on. 
And the great thing about reinforcement learning is we don't need a large, labeled data set. But we do need to simulate what the model will see, as it's acting as the agent in our program. So with Grain, we can import an existing Hugging Face data set. I'm using the Food 101 data set, which has over 100,000 images of food. And then I'm adding a pre-processing step, which will then simulate what the model might see when it's acting as our agent. And then we can use Grain's Map data set to pipe these all together so that it will efficiently distribute this to the different chips as we're training our model. Just to show you here, here is one of the examples. Some yummy looking pizza there, and then the input prompt, and an example what it might see when it needs to find all the different ingredients in this food. So now that we've set up our data set, we actually will define some rewards. Instead of having strict labels here, we'll just define rewards that the model will try to optimize against. And for this, I had a number of those making sure it correctly called the tools, that it had a correct response from the database using this thought process, and also correctly formatting the output at the end of the process. But the real thing you need to know about this-- this is the important part if you're designing these-- to really think through how to make these rewards reach what you want your model to end up acting like once it's trained. OK, so here's the power of Tunix. At this point, we've created our actor model, our reference model, our data set, and our rewards. And all we have to do is use this reinforcement learning cluster. And Tunix will take care of all the distribution, coordinating the actor reference and rollout across all of our TPUs. And then lastly, we create this trainer, and we train on it. So we provide this cluster our rewards we defined earlier, and we train on our data set. Now, this does take 30, 45 minutes to run all the way through. 
So I'm not going to actually do that in this demo. But we are going to look at some of the metrics and results. And this great thing about Tunix, it handles all the logging of checkpoints and allows us to easily go into TensorBoard and see, on different experiments, how our model is responding to that and how it's actually improving, achieving the rewards that we set up. All right, so lastly, let's just jump back over to this example UI. And I'm going to switch from Gemma 4 to using our fine-tune model. Now, the great thing about this is the model is much smaller. So I'm able to run it on a single Trillium chip, instead of the eight or four that it would require for the larger model. And then it works. Thanks. So now I'm going to turn it over to Girija again. She's going to talk more about MaxText. [APPLAUSE] GIRIJA SATHYAMURTHY: It was a cool demo. Thanks so much, Rob. Awesome. Let's talk a little bit about pre-training. So one of the things that you will notice across both post-training and pre-training is that the challenges are pretty similar. You're still dealing with distributed systems execution and memory efficiency or checkpointing and so on. So the real question becomes, how do you make frontier scale training much more reproducible so your developers are not always having to start from scratch? This is where MaxText comes in. MaxText is an open-source reference implementation that helps you train large scale foundational models on TPUs. But MaxText is much more than a simple framework. It gives you a lot of battle-tested training configurations across a wide range of model families, like Qwen, Mistral, Gemma, DeepSeek, and so on. So the idea is that your developers are not starting from scratch and can instead use these battle-tested recipes and adapt from there. Under the hood, MaxText is powered by JAX and XLA. JAX offers the functional programming model that helps you with model definitions and parallelism strategies. 
And XLA helps with compilation and also helps you with distributed systems execution in a very optimized manner on TPUs. And the combination is great because it allows you to take the same training workflow and go from a single-host experimentation all the way to multi-pod production runs. And you don't have to worry about changing your code structure or your training workflow. MaxText encodes many of our hard-learned lessons in large scale training into these battle-tested recipes. So the idea is to make large scale training much more reproducible, accessible, and just less infra-heavy, so you can just focus on model development. We talked about some of these updates a little bit just now. So MaxText supports, like I said, a wide range of model families, all the way from Qwen, DeepSeek, Gemma 4 is our latest one. There's a couple more that are landing pretty soon as well. It also has some robust capabilities, like multi-token prediction, and performance-enhancing capabilities, like support for custom kernels and multi-tiered checkpointing. There's a link up here to our MaxText repo, which hosts all of our recipes, so please do check them out. Now that we have talked a little bit about the different tools you can use at different stages of the model development lifecycle, let's spend a few minutes talking about frameworks. Josh, you want to come over and help us with this? JOSH GORDON: Yeah, I'm back. Thanks. OK, so one thing that can help when we're talking about TPUs, and maybe we could have showed this earlier, it's nice to just visualize, to get a sense of the scale that we're talking about. Can we switch back to the laptop, please, and I'll show a quick demo? So we have a TPU visualizer, which came out a couple months ago. And if you're brand new to TPUs, I just want to show you what this can look like really quick. So with this particular generation of TPU-- and just for visualization, we happen to have a v5p selected. It's OK. 
What we're looking at here in red, that's a Linux virtual machine. So you could use something like Google Compute Engine, spin up a VM. And then in this case, the blue, we've got two TPU chips attached to that. So it's just one machine you can SSH into. You've got two chips. What's cool about TPUs, of course, is you can attach more chips. So now we've got four chips attached. And there's lots we could talk about, but just a couple of things to glance at. You can notice that the TPUs are directly connected to each other. So in the blue at the top, the TPUs can actually send data back and forth without going through the virtual machine itself. And that's for speed. And there's all different kinds of configurations. So a really common one would be a cube. So this is a 4x4x4 arrangement of TPUs. And what's cool in a cube here, there's different types of connections between the TPU chips. There's copper when they're super close. And there's optical when they're a little bit farther away. But the goal, and the reason we have lots of connectivity like this, is we want to minimize the distance between different chips in the cube, or the cluster, for fast interconnect. Anyway, TPUs scale up really, really fast. Oh, by the way, of course, here we've got a whole bunch of virtual machines. So now, obviously, we've got a lot of them. When you have a lot like this, we've got all sorts of software which is out of scope for this talk. But you can use things like Kubernetes to help you really manage a cluster like this. And then what's cool about TPUs in this generation-- oops. This is just a pod slice. So now we have thousands and thousands of chips that are all connected together. And the scale is just wild. I'm a Star Trek fan, so this reminds me almost of the Borg. This is actually a little bit less than a pod. But what's interesting is, beyond this scale, then we move into, like, a data center network. And so you can have many, many, many pods like this connected together. 
And the idea that I just want to show you, when you think of a TPU, you can think of one chip that's plugged into a Linux VM. And you can also think about almost having an ocean of compute. And what's really cool about the software stack and our frameworks, a lot of them are designed to make a lot of the magic that goes into all this complicated-- Oh my God, how do you coordinate thousands and thousands of chips? And how do they all talk to each other? And how do you do that efficiently? And how do you make sure that all the chips are working at the same time so none are idle, and all that stuff? Things like MaxText and JAX handle all of that behind the scenes for you. And you can optimize it, but we've done a huge amount of work to make a lot of that transparent. And so with something like JAX, that we'll talk about in a minute, you can write code that feels like you're working on a Linux VM. But really, that code can be running on something like this, which is absolutely wild. So it's super cool. You can find this in the TPU docs on Google Cloud. It's also on bit.ly/tpu-viz, for visualizer. OK, let's switch back to the slides, please. Thanks.
So I just wanted to say very briefly, we'll talk about JAX and TorchTPU in a second. Beneath both of those is our open-source compiler called OpenXLA. And the main thing I wanted to say about this is this does a lot of the magic, which makes all that inter-chip communication possible without you having to think about it. So it's just this super powerful open compiler. We've actually got a whole upcoming conference on this, that I'll talk about in a bit, if you want to go much deeper. But yeah, OpenXLA is awesome. It does a lot of stuff for you, makes your life easy, which is what I want. OK, here's the clicker.
GIRIJA SATHYAMURTHY: Yes, thank you. Awesome. So now with that little bit of foundation, I think talking about JAX and PyTorch will make a lot more sense. So let's start with JAX. Gemini is built on top of JAX.
And the reason is that JAX offers a very clean programming model that's great for large-scale ML systems. At its core, it's just like NumPy with some composable transformations. For example, grad lets you automatically compute gradients of simple Python functions. Then jit traces your Python functions and compiles them with XLA into optimized accelerator programs. There's also vmap, which vectorizes a function over a batch dimension and helps you express parallelism cleanly, and so on. So, as you can see, at its core, JAX is pretty minimal in the sense that you're writing simple, straightforward Python functions and then progressively adding transformations to make them differentiable, vectorized, compiled, and scalable across accelerators. So instead of having to learn a heavy, large framework, all you're doing is writing straightforward Python functions and progressively adding transformations, like we talked about.
So JAX is pretty minimal at its core, and what makes it great is the ecosystem around it. In real-world systems, you need much more than the composable transformations that we talked about, because you need things like optimizers, data pipelines, checkpointing, and so on. This is where the ecosystem comes in. It's a set of libraries that provide the higher-level model-building abstractions for those things. For example, Flax is a very popular neural network library that provides these abstractions to help you with modules and parameter management. It's especially useful when your model stops being a simple Python function and starts having things like nested loops and recurring blocks. A couple of other examples: Orbax is great with checkpointing and state management, and Grain is great for input pipelines and data loading. So you get the idea. JAX is pretty minimal, but together with its ecosystem, it's ideal for large-scale ML systems.
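[Editor's sketch, not shown in the talk: a minimal illustration of the three transformations named above, using the public jax.grad, jax.jit, and jax.vmap functions. The linear model, the loss function, and the toy data are assumptions made up for illustration.]

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    """Mean squared error of a toy linear model (hypothetical example)."""
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)


# grad: build a function that returns d(loss)/d(w), the first argument.
grad_loss = jax.grad(loss)

# jit: trace the gradient function and compile it with XLA.
fast_grad_loss = jax.jit(grad_loss)

# vmap: evaluate the loss one example at a time by mapping over the
# leading axis of x and y while sharing the same weights w.
per_example_loss = jax.vmap(loss, in_axes=(None, 0, 0))

# Made-up data: three examples with two features each.
w = jnp.array([1.0, -2.0])
x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([1.0, -2.0, 0.0])

g = fast_grad_loss(w, x, y)         # gradient, same shape as w
losses = per_example_loss(w, x, y)  # one loss value per example
print(g.shape, losses.shape)
```

The same plain Python function is reused three ways without being rewritten, which is the "progressively adding transformations" idea from the talk. Only the first two examples fit the toy model exactly, so the third contributes all of the loss and gradient here.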
We'll also quickly touch on TorchTPU. Like Josh mentioned earlier, this is a complete, ground-up rebuild of the PyTorch stack on top of TPUs. And we are building this in close partnership with Meta because we want to stay aligned to the overall direction of PyTorch and of the community. I'll talk about some core principles that we are building this around to give you an idea of how this is going to shape up. The first one is that it's fully native, meaning there are no new APIs. So if you know PyTorch, you already know TPUs. The second is around portability. So the idea here is to help with rapid experimentation so researchers and developers can just run models right out of the box, and you're not dealing with upfront portability issues. The third is making sure there are well-lit paths for performance, using things like torch.compile and profiling tools and kernel support. And the fourth is just making sure we are closely aligned with the ecosystem and enabling all PyTorch developers to easily run their workloads on TPUs.
Many ML developers are using CUDA and PyTorch workflows today. But with increasingly heterogeneous backend systems becoming a reality, we want to help developers easily move their workloads onto TPUs without requiring them to rewrite their application layer. In many cases, the actual code changes that are required are pretty minimal, as you can see. And if you're interested in learning more, stay tuned and sign up to our newsletter so you can keep up with the updates here.
JOSH GORDON: Thanks, Girija. OK, also, one last really cool project before we wrap up. So this is Kinetic. It's a new project from the Keras team. And what I really like about this is it's designed to make running your code on TPUs a lot easier just using a decorator. So as you can imagine, when you're setting up stuff on Google Cloud, there's some DevOps work involved. And what Kinetic does is it handles that for you.
So it can go ahead, and it can set up a cluster that's ready to run your code on TPUs without you having to configure a bunch of stuff manually. So really nice new project. We've been doing some hackathons and some sprints. It's been really popular. So I think it's neat. Check it out if you're a Keras fan and you're interested in running code on TPUs easily.
OK, we covered a lot of really cool stuff. And as you can imagine, for anything like vLLM on TPUs, MaxText, Tunix, TorchTPU, JAX, you could get almost a degree in this stuff. There's a lot of really cool things to learn. And so we have a really cool upcoming conference. It's called the AI Systems DevLab. And what this is, it's really an opportunity to hear directly from the engineering team. So we have a lot of really good, really amazing deep talks from a lot of software engineers, a lot of product managers, and folks building this software directly. And they will go really deep on all of this stuff. So it's a great place to learn more. And if you're a power user, it's really the best introduction to all of this that you can get. Join us if you'd like. All the talks are going to be recorded, and we'll upload them to our YouTube channel after.
Yeah. So thank you very much for your time. We really appreciate it. And here are some links to learn more. [MUSIC PLAYING]
