
Let it rip: a diffusion tutorial

Track:
Machine Learning, NLP and CV
Type:
Tutorial
Level:
Intermediate
Duration:
180 minutes

Abstract

Implementing high-performance deep learning models often feels like a struggle between readable Python code and the low-level optimizations required for modern GPUs and TPUs. JAX bridges this gap by treating neural networks as pure mathematical transformations. In this session, we will move beyond the abstractions of high-level frameworks to build a Denoising Diffusion Probabilistic Model (DDPM) from the ground up.
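To give a taste of what "from the ground up" means here, the DDPM forward process has a closed form, q(x_t | x_0) = N(sqrt(ᾱ_t) x_0, (1 − ᾱ_t) I), that fits in a few lines of JAX. A minimal sketch, assuming an illustrative linear beta schedule with T = 1000 steps (the tutorial's actual schedule may differ):

```python
import jax
import jax.numpy as jnp

# Illustrative linear noise schedule (assumption: T = 1000, betas in [1e-4, 0.02]).
T = 1000
betas = jnp.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = jnp.cumprod(alphas)  # cumulative product: alpha_bar_t

def q_sample(x0, t, key):
    """Sample x_t ~ q(x_t | x_0) in closed form:
    x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps, eps ~ N(0, I)."""
    eps = jax.random.normal(key, x0.shape)
    ab = alpha_bars[t]
    return jnp.sqrt(ab) * x0 + jnp.sqrt(1.0 - ab) * eps, eps
```

Because the marginal is known in closed form, training never has to simulate the noising chain step by step; any timestep can be sampled directly.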

We will explore how JAX’s functional programming paradigm is uniquely suited for the stochastic nature of diffusion. You will learn how to:

  • Master JIT (Just-In-Time) compilation: See how @jax.jit transforms Python functions into optimized XLA kernels for massive speedups.
  • Leverage vectorized mapping: Use @jax.vmap to handle data parallelism across batches without the overhead of manual loops.
  • Dissect the diffusion pipeline: Step through the forward noising process (SDEs) and the reverse denoising process (score matching).
  • Manage state and PRNGs: Navigate JAX’s unique, explicit handling of random number generation and stateless transformations.
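The three transformations above compose. A minimal sketch of how they might fit together in a denoising loss, assuming a toy stand-in model and a simplified continuous schedule with t in [0, 1] (both hypothetical, chosen so the transformations stay in focus):

```python
import jax
import jax.numpy as jnp

def predict_noise(params, x_t, t):
    # Toy stand-in for the denoiser: a real DDPM would use a U-Net (assumption).
    return x_t * params["scale"] + params["bias"]

def single_loss(params, x0, t, key):
    # Explicit PRNG handling: every random op consumes a key supplied by the caller.
    eps = jax.random.normal(key, x0.shape)
    # Simplified stand-in schedule with continuous t in [0, 1] (assumption).
    x_t = jnp.sqrt(1.0 - t) * x0 + jnp.sqrt(t) * eps
    # Simple denoising objective: predict the noise that was added.
    return jnp.mean((predict_noise(params, x_t, t) - eps) ** 2)

# vmap maps the per-example loss over the batch axis of x0, t, and keys;
# jit then compiles the whole batched computation into one XLA kernel.
batched_loss = jax.jit(
    lambda params, x0s, ts, keys: jnp.mean(
        jax.vmap(single_loss, in_axes=(None, 0, 0, 0))(params, x0s, ts, keys)
    )
)
```

Note that `single_loss` is written for one example and one key; `jax.vmap` turns it into a batched function without any Python loop, and `jax.random.split` is how a batch of independent keys would be produced from a single root key.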

This tutorial is designed for Python developers and ML engineers who want to understand the "how" and "why" behind state-of-the-art text-to-image models. You will leave with a deep understanding of the diffusion objective and the practical skills to deploy high-performance model architectures using the JAX ecosystem.