Unlocking Low-Level Control: Customizing Keras Training Loops with JAX

Google · Google for Developers · April 30, 2026 · 6:33

TL;DR

Keras supports custom deep learning training loops on the JAX backend: you override train_step (and test_step for evaluation) while keeping high-level features such as model.fit, callbacks, and distribution support.

Key Points

Custom Training with Keras Fit

Keras allows developers to override the train_step method to implement custom learning algorithms while still using the high-level model.fit API. This approach maintains access to built-in conveniences such as callbacks, metrics, and distributed training, avoiding the need to abandon the framework’s abstractions.

Progressive Disclosure of Complexity

The design follows a principle known as progressive disclosure, enabling users to gradually move from simple workflows to more advanced control. Developers can start with standard training loops and incrementally introduce custom logic without rewriting the entire training pipeline.

JAX Backend Requirements

Custom training with JAX requires configuring Keras to use the JAX backend before importing the library. JAX emphasizes stateless computation, meaning all model components—trainable weights, non-trainable variables, optimizer state, and metrics—must be explicitly passed into and returned from functions.

Stateless Training Step Design

In a JAX-based setup, the train_step function operates entirely on a state tuple. This tuple includes all relevant variables and is updated and returned after each batch. Stateless versions of model operations, such as call and loss computation, are used to ensure compatibility with JAX’s functional paradigm.

Loss and Gradient Computation

A helper function, often structured to compute both loss and auxiliary updates, performs the forward pass and loss calculation. Gradients are then derived using JAX transformations such as value_and_grad, which simultaneously computes the loss value and its gradients, improving efficiency and reducing redundant code.
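To illustrate the difference between separate calls and the combined transformation, here is a small pure-JAX sketch. The function mse_loss and the toy arrays are hypothetical stand-ins for the Keras helper, not code from the video.

```python
import jax
import jax.numpy as jnp


def mse_loss(w, x, y):
    """Mean squared error of a linear model (hypothetical stand-in)."""
    y_pred = x @ w
    return jnp.mean((y_pred - y) ** 2)


w = jnp.array([0.5, -1.0])
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 0.0])

# Two separate calls: the argument list has to be written twice.
loss_separate = mse_loss(w, x, y)
grads_separate = jax.grad(mse_loss)(w, x, y)

# One combined call: returns the loss and its gradient with respect to
# the first argument (w) together.
loss, grads = jax.value_and_grad(mse_loss)(w, x, y)
```

Both routes give the same numbers; the combined form simply avoids repeating the call.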

Handling Auxiliary Outputs

The use of has_aux=True in gradient computation allows functions to return both differentiable outputs, such as loss, and non-differentiable auxiliary data, including updated non-trainable variables. This ensures that only relevant components are included in gradient calculations.
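A pure-JAX sketch of the has_aux pattern, under assumed names: loss_and_updates and the running-mean "non-trainable" state below are hypothetical and only mimic the structure of the Keras helper.

```python
import jax
import jax.numpy as jnp


def loss_and_updates(w, running_mean, x, y):
    """Return (loss, aux). Only the loss is differentiated."""
    y_pred = x @ w
    loss = jnp.mean((y_pred - y) ** 2)
    # Hypothetical non-trainable state: a running mean of the inputs.
    new_running_mean = 0.9 * running_mean + 0.1 * jnp.mean(x, axis=0)
    return loss, (y_pred, new_running_mean)


w = jnp.array([0.5, -1.0])
running_mean = jnp.zeros(2)
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 0.0])

grad_fn = jax.value_and_grad(loss_and_updates, has_aux=True)

# The output nests as ((value, aux), grads); grads are taken with respect
# to w only, and aux is passed through untouched.
(loss, (y_pred, running_mean)), grads = grad_fn(w, running_mean, x, y)
```

Without has_aux=True, JAX would expect the function to return a single scalar and would raise an error on the tuple.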

Applying Updates with Optimizers

Keras provides stateless optimizer methods for JAX workflows. The optimizer’s stateless_apply function updates both trainable variables and optimizer state in a functional manner, aligning with JAX’s requirement to avoid in-place mutations.

Metric Tracking in Stateless Mode

Metrics are updated using stateless methods as well, ensuring their internal variables are included in the state tuple. This enables accurate tracking of training performance without breaking the functional structure required by JAX.

Custom Evaluation Logic

Similar customization is possible for evaluation by overriding the test_step method. This process reuses the loss computation logic but skips weight updates, focusing instead on calculating and recording evaluation metrics.

Conclusion

By combining JAX’s functional programming model with Keras’ extensible design, developers gain precise control over training logic while retaining the productivity benefits of high-level APIs.

Full transcript

When working with deep learning models, Keras provides a convenient model.fit API for training. But what if your algorithm requires a custom training procedure and you still want to benefit from high-level features like callbacks and built-in distribution support? Well, Keras follows an important core principle: the progressive disclosure of complexity. This means you should be able to gain more low-level control without having to abandon all that high-level convenience.

Hi there, my name is Yufeng, and today we will look at how to customize the Keras training loop with the JAX backend by overriding the training step.

To customize what fit does, one approach is to override the train_step function of the Model class. This is the function that is called by fit for every batch of data, allowing you to run your own learning algorithm while still using fit as usual. This approach works for Sequential models, Functional API models, or subclassed models.

Since this customization specifically targets JAX optimization, we must ensure our environment is correctly configured to use the JAX backend before importing Keras. Customizing the training step in JAX requires understanding JAX's reliance on stateless computation. The entire train_step method must be fully stateless. What does this mean? Well, it means that all elements of the model state (the trainable variables, the non-trainable variables, the optimizer variables, and the metrics variables) are explicitly passed as inputs within a state tuple, and then returned as updated versions of those same variables. We'll use stateless versions of the call, apply, and compute loss functions, so keep an eye out for how their inputs and outputs end up being structured.

A key part of the JAX implementation is calculating the gradients, typically using a helper function. In this example, we made one called compute_loss_and_updates. As the name suggests, it computes the loss, and the updates portion refers to updating those non-trainable variables. So two things are happening in this function. We first do a forward pass of the model using the model's stateless_call, passing the trainable variables and non-trainable variables explicitly as inputs. Using stateless_call is important with the JAX backend, and this method will return the predictions, y_pred, and any updated non-trainable variables as well. Then we'll use these values to compute the loss by passing in the expected y values along with those predicted y values.

Now let's look at the train_step, where we call our helper function compute_loss_and_updates. The plan is to use JAX's grad function transformation, which will produce a new function that computes the gradient of our helper function. This means we'll end up with a function that computes the gradient of the loss, which is perfect for our machine learning needs. In practice, we want to compute the value of the loss as well as its gradient, so we'll do that to our helper function using jax.value_and_grad. This is simply combining the operations of calling our helper function itself and creating its gradient function. When we compare doing this manually with separate calls to the function and grad, we'll see they're very similar, except you'll notice we didn't have to write our list of arguments to compute_loss_and_updates twice, which is nice.

So you might also be wondering why we use the has_aux=True argument here. This argument indicates that the function being differentiated returns a pair where the first element is the value to be differentiated, and the second is auxiliary data that should not be differentiated. Our loss computation function returns exactly this: the loss, which we want to differentiate, followed by some other stuff that we're kind of just passing through, the y_pred and those updated non-trainable variables.

Inside of train_step, after computing the gradients, we update the variables. Keras provides stateless methods for this too. To apply the gradients, we use the optimizer's stateless apply method. It's aptly named optimizer.stateless_apply, which returns the updated trainable variables along with the updated optimizer variables. If you are using Keras metrics, you can handle them by calculating the metric result inside the step function using the metric's stateless_update_state, ensuring these new metric variables are also returned in the state tuple.

This same capability for low-level control can be applied to evaluation. So to customize how model.evaluate operates, override the test_step method with a call to compute_loss_and_updates, just like we did in train_step, except we don't need to perform the stateless apply to update the weights this time, since we're just evaluating, and we can move directly to recording our loss metrics for evaluation scoring.

By subclassing keras.Model and overriding methods like train_step and test_step, you gain fine-grained control over the training and evaluation algorithms using JAX's powerful functional tools, all while preserving the convenient high-level features provided by model.fit. So what custom algorithms have you integrated into Keras models? Share your thoughts in the comments below. And for the full documentation and complete code examples, check out the guide linked in the description below. Thanks for watching, and I'll catch you in the next one.
