ENFR
8news

Tech • IA • Crypto

Aujourd'huiVidéosRécaps vidéoArticlesTop articlesArchives

Introducing Keras Recommenders: state-of-the-art recommendation techniques at your fingertips

GoogleGoogle for Developers28 avril 20266:35
0:00 / 0:00

INTRO

L’intégration récente de Keras avec le système modulaire Flax et NNX ouvre de nouvelles perspectives, combinant la simplicité de Keras avec la puissance et la flexibilité de JAX pour la gestion avancée des variables et les entraînements personnalisés.

Points clés

Interopérabilité entre Keras, Flax et NNX

La nouvelle liaison fait de keras.Variable une instance de nnx.Variable, ce qui permet une coexistence harmonieuse des états entre Keras et NNX. Cette compatibilité facilite le mélange des composants issus des deux écosystèmes sans rupture dans la gestion des variables.

Activation par variables d’environnement

Pour bénéficier de cette intégration, il faut configurer deux variables d’environnement avant d’importer Keras: définir le backend sur JAX et activer explicitement le mode NNX. Cette étape assure que Keras utilise les capacités avancées de NNX et Flax.

Gestion explicite et traçage des variables

Le système permet de vérifier que les variables Keras sont bien reconnues et suivies par NNX: elles apparaissent dans la liste des variables gérées, disposent d’un état de trace pour la compilation juste-à-temps et peuvent être accédées directement via NNX, renforçant le contrôle lors des phases d’entraînement.

Modularité et définition simplifiée des modèles

NNX offre une approche modulaire avec des classes Python standard, introduisant une modularité native dans Keras grâce à l’intégration. Par exemple, un module NNX peut contenir des couches linéaires, des variables personnalisées Keras et manipuler leur interaction dans la méthode d’appel du modèle.

Deux modes de formation flexibles

Le premier mode correspond à la méthode classique de Keras: model.compile et model.fit fonctionnent normalement, avec NNX et JAX en coulisses pour optimiser les performances, sans nécessiter de refonte du code.

Entraînement personnalisé pour plus de contrôle

Le second mode permet d’exécuter des boucles d’entraînement personnalisées en traitant un modèle Keras comme un nnx.Module. Cela donne accès au vaste écosystème JAX, y compris Optax pour les optimisateurs, et aux fonctions nnx.jit et nnx.grad pour accélérer et différencier efficacement les modèles.

Just-in-time compilation et accélération

L’utilisation du décorateur nnx.jit améliore les performances en compilant à la volée les fonctions d’entraînement, ce qui donne un gain notable en vitesse même sur des modèles Keras intégrés à NNX.

Unification de prototypage rapide et recherche avancée

Cette symbiose permet de débuter facilement avec les API familières de Keras tout en accédant progressivement à des fonctionnalités avancées de recherche et optimisation propres à JAX, créant ainsi un cadre unifié pour tous les niveaux d’expertise.

Accès aux bibliothèques JAX et extensibilité

Les modèles Keras ainsi intégrés peuvent exploiter pleinement les outils JAX comme Optax, un large éventail de fonctions statistiques et différentielles, étendant considérablement les possibilités au-delà du cadre Keras traditionnel.

Ressources officielles disponibles

Un guide complet et des exemples de code sont proposés sur le site keras.io, offrant un point de départ clair pour maîtriser cette nouvelle approche et expérimenter concrètement cette intégration.

CONCLUSION

L’intégration de Keras avec Flax et NNX révolutionne la gestion d’état et les possibilités d’entraînement, offrant à la fois simplicité d’utilisation et puissance de personnalisation dans l’écosystème JAX. Ce pont ouvre la voie à des développements plus flexibles et performants en machine learning.

Transcription complète

Keras provides user-friendliness and accessibility, while JAX offers high-performance numerical computation. We can now leverage the strengths of both, especially for detailed state management and advanced training capabilities within the JAX ecosystem. All thanks to a new integration of Keras with Flax and NNX. Hi there, my name is Yufeng, and today we're going to check out how to use Keras with the Flax and NNX module system, demonstrating how this integration enhances your variable handling and training control. Keras is highly valued for its high-level API and intuitiveness, making deep learning development straightforward. JAX is excellent for high-performance machine learning research and scalability. NNX is a modular neural network library designed for simplicity and power built on top of JAX. It promotes ease of use through standard Python classes for modules and offers explicit state management via typed variable collections. The integration of Keras with NNX allows you to use the modularity of Keras for model construction, while benefiting from the power and explicit control of NNX and JAX for variable management and advanced training loops. So, to activate this feature, we must first set two environment variables before importing Keras. This enables NNX as an opt-in feature. We set the backend to JAX and then explicitly enable NNX mode by setting the Keras NNX enabled environment variable to true. The core of this integration lies in the keras.Variable, which is designed to be an instance of nnx.Variable from the Flax and NNX ecosystem. This means you can freely mix Keras and NNX components, and NNX's state management tools will successfully track your Keras variables. Here's an example of what that looks like. We've got an NNX module that has a linear layer we've called linear, and a vector value I've named NNX variable, so we can easily keep track of things. I've also added a Keras variable as part of the model called custom variable. We can see in the call function we're adding the NNX variable and the custom variable to the results of the linear transform being applied. Once we have the model instantiated, there are a couple of tests we can run to really verify that the custom variable is set up just as we'd expect. First, we check to see that the Keras variable has what's called a trace state, meaning that NNX has successfully traced through this variable, allowing it to just-in-time compile it along with the rest of the model. We can do this by confirming that it has the attribute trace state. Second, we want to make sure that NNX is counting this variable among all the variables it's aware of using nnx.variables. This shows all the variables that NNX is tracking, and indeed, we do see that our custom variable is listed. Third and finally, let's confirm that the variable's value can be accessed directly by NNX, even though it's a Keras variable inside of an NNX model, the NNX model has no problem fetching its value. Hopefully, I've convinced you by now that keras.Variable is successfully integrated with NNX, allowing Keras state and NNX state to coexist seamlessly. This integration provides two powerful training workflows. The first one is going to feel just like classic Keras, but it runs NNX modules inside of Keras, letting Keras manage the training workflow. The other approach uses NNX to run the training workflow with Keras models inside of NNX training loops. So, in this first version, your existing high-level Keras code, including model.compile and model.fit, it all works out of the box. And under the hood, this productive experience is powered by JAX and NNX. Here we have the other path. For maximum flexibility and fine-grained control, you can treat any Keras model or layer as an nnx.Module. This allows you to write your own custom training loop using JAX libraries, such as Optax for optimizers, while mix and matching the model's components just as we saw in our very first example with custom variable. You can think of this as Keras inside of NNX. Here we see an example of a Keras model with a couple of dense layers, and once that model is created, the rest of the workflow is entirely NNX and JAX code. We'll select an optimizer, we'll write a custom train step to compute the loss, the gradients, and perform the updates to the model weights. Notice that we're using the decorator nnx.jit instead of jax.jit. This special decorator speeds up your NNX code by using just-in-time compilation, and our Keras model gets to benefit from it, too. In short, a Keras model object gets to do everything that an NNX model can do. It's able to be passed seamlessly to NNX optimizer, differentiated using nnx.grad, and used with the broader JAX ecosystem of libraries. The Keras NNX integration offers a significant step forward, providing a unified framework for both rapid prototyping and high-performance customizable research. You can leverage the entire JAX ecosystem, including nnx.jit and libraries like Optax, while still using familiar Keras APIs like model.fit and model.save. The code shown today was an adapted sample of a complete guide on the keras.io website. So, if you're ready to dive in and try out Keras with NNX for yourself, definitely go through that guide first. It's the perfect starting point to get hands-on. So, what are you going to be building with Keras along with the NNX backend? Share your thoughts in the comments below. Remember, for the complete guide and code examples, hit up that link in the description. Thanks for watching, and I'll catch you in the next one.

Sur le même sujet : Google