8news



Introducing Keras Recommenders: state-of-the-art recommendation techniques at your fingertips

Google for Developers • April 28, 2026 • 6:35

TL;DR

The recent integration of Keras with Flax’s modular NNX system opens new possibilities, combining Keras’s simplicity with the power and flexibility of JAX for advanced variable management and custom training.

Key Points

Interoperability between Keras, Flax, and NNX

The new linkage makes keras.Variable an instance of nnx.Variable, enabling seamless coexistence of state between Keras and NNX. This compatibility makes it easy to mix components from both ecosystems without breaking variable management.

Activation via environment variables

To use this integration, two environment variables must be set before importing Keras: KERAS_BACKEND must be set to jax, and NNX mode must be explicitly enabled by setting KERAS_NNX_ENABLED to true. This ensures Keras leverages the advanced capabilities of NNX and Flax.
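Concretely, the opt-in looks like this. The variable names are as stated in the video; the Keras import is left commented out so the snippet stands on its own:

```python
import os

# Both variables must be set BEFORE Keras is imported,
# or the backend choice will not take effect.
os.environ["KERAS_BACKEND"] = "jax"
os.environ["KERAS_NNX_ENABLED"] = "true"

# import keras  # Keras now uses the JAX backend with NNX mode enabled
```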

Explicit variable management and tracking

The system allows verification that Keras variables are properly recognized and tracked by NNX: they appear in the managed variable list, have a trace state for just-in-time compilation, and can be accessed directly via NNX, improving control during training.

Modularity and simplified model definition

NNX provides a modular approach using standard Python classes, bringing native modularity into Keras through integration. For example, an NNX module can include linear layers, custom Keras variables, and manage their interaction within the model’s call method.

Two flexible training modes

The first mode follows the classic Keras approach: model.compile and model.fit work as usual, with NNX and JAX operating behind the scenes to optimize performance, without requiring code changes.
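In code, this first mode is just ordinary Keras; the tiny regression model and random data below are illustrative, not from the video:

```python
import os

os.environ["KERAS_BACKEND"] = "jax"
os.environ["KERAS_NNX_ENABLED"] = "true"

import numpy as np
import keras

# An unchanged Keras workflow; JAX and NNX do the work underneath.
model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")
history = model.fit(x, y, epochs=2, batch_size=16, verbose=0)
```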

Custom training for greater control

The second mode enables custom training loops by treating a Keras model as an nnx.Module. This gives access to the broader JAX ecosystem, including Optax for optimizers, and nnx.jit and nnx.grad for efficient acceleration and differentiation.

Just-in-time compilation and acceleration

Using the nnx.jit decorator improves performance by compiling training functions on the fly, providing noticeable speed gains even for Keras models integrated with NNX.

Unifying rapid prototyping and advanced research

This synergy allows users to start easily with familiar Keras APIs while progressively accessing advanced research and optimization features from JAX, creating a unified framework for all skill levels.

Access to JAX libraries and extensibility

Integrated Keras models can fully leverage JAX tools like Optax, along with a wide range of statistical and differential functions, greatly extending possibilities beyond the traditional Keras framework.

Official resources available

A comprehensive guide and code examples are available on the keras.io website, offering a clear starting point to learn and experiment with this integration.

CONCLUSION

The integration of Keras with Flax and NNX transforms state management and training possibilities, combining ease of use with powerful customization in the JAX ecosystem. This bridge enables more flexible and high-performance machine learning development.

Full transcript

Keras provides user-friendliness and accessibility, while JAX offers high-performance numerical computation. We can now leverage the strengths of both, especially for detailed state management and advanced training capabilities within the JAX ecosystem. All thanks to a new integration of Keras with Flax and NNX. Hi there, my name is Yufeng, and today we're going to check out how to use Keras with the Flax and NNX module system, demonstrating how this integration enhances your variable handling and training control. Keras is highly valued for its high-level API and intuitiveness, making deep learning development straightforward. JAX is excellent for high-performance machine learning research and scalability. NNX is a modular neural network library designed for simplicity and power built on top of JAX. It promotes ease of use through standard Python classes for modules and offers explicit state management via typed variable collections. The integration of Keras with NNX allows you to use the modularity of Keras for model construction, while benefiting from the power and explicit control of NNX and JAX for variable management and advanced training loops. So, to activate this feature, we must first set two environment variables before importing Keras. This enables NNX as an opt-in feature. We set the backend to JAX and then explicitly enable NNX mode by setting the Keras NNX enabled environment variable to true. The core of this integration lies in the keras.Variable, which is designed to be an instance of nnx.Variable from the Flax and NNX ecosystem. This means you can freely mix Keras and NNX components, and NNX's state management tools will successfully track your Keras variables. Here's an example of what that looks like. We've got an NNX module that has a linear layer we've called linear, and a vector value I've named NNX variable, so we can easily keep track of things. I've also added a Keras variable as part of the model called custom variable. 
We can see in the call function we're adding the NNX variable and the custom variable to the results of the linear transform being applied. Once we have the model instantiated, there are a few tests we can run to really verify that the custom variable is set up just as we'd expect. First, we check to see that the Keras variable has what's called a trace state, meaning that NNX has successfully traced through this variable, allowing it to just-in-time compile it along with the rest of the model. We can do this by confirming that it has the attribute trace state. Second, we want to make sure that NNX is counting this variable among all the variables it's aware of using nnx.variables. This shows all the variables that NNX is tracking, and indeed, we do see that our custom variable is listed. Third and finally, let's confirm that the variable's value can be accessed directly by NNX: even though it's a Keras variable inside of an NNX model, the NNX model has no problem fetching its value. Hopefully, I've convinced you by now that keras.Variable is successfully integrated with NNX, allowing Keras state and NNX state to coexist seamlessly. This integration provides two powerful training workflows. The first one is going to feel just like classic Keras, but it runs NNX modules inside of Keras, letting Keras manage the training workflow. The other approach uses NNX to run the training workflow with Keras models inside of NNX training loops. So, in this first version, your existing high-level Keras code, including model.compile and model.fit, it all works out of the box. And under the hood, this productive experience is powered by JAX and NNX. Here we have the other path. For maximum flexibility and fine-grained control, you can treat any Keras model or layer as an nnx.Module. 
This allows you to write your own custom training loop using JAX libraries, such as Optax for optimizers, while mixing and matching the model's components just as we saw in our very first example with custom variable. You can think of this as Keras inside of NNX. Here we see an example of a Keras model with a couple of dense layers, and once that model is created, the rest of the workflow is entirely NNX and JAX code. We'll select an optimizer, we'll write a custom train step to compute the loss, the gradients, and perform the updates to the model weights. Notice that we're using the decorator nnx.jit instead of jax.jit. This special decorator speeds up your NNX code by using just-in-time compilation, and our Keras model gets to benefit from it, too. In short, a Keras model object gets to do everything that an NNX model can do. It's able to be passed seamlessly to an NNX optimizer, differentiated using nnx.grad, and used with the broader JAX ecosystem of libraries. The Keras NNX integration offers a significant step forward, providing a unified framework for both rapid prototyping and high-performance customizable research. You can leverage the entire JAX ecosystem, including nnx.jit and libraries like Optax, while still using familiar Keras APIs like model.fit and model.save. The code shown today was adapted from a complete guide on the keras.io website. So, if you're ready to dive in and try out Keras with NNX for yourself, definitely go through that guide first. It's the perfect starting point to get hands-on. So, what are you going to be building with Keras along with the NNX backend? Share your thoughts in the comments below. Remember, for the complete guide and code examples, hit up that link in the description. Thanks for watching, and I'll catch you in the next one.
