Graph execution is a computation paradigm in machine learning frameworks where mathematical operations are organized into a directed acyclic graph (DAG) before being executed. In this model, nodes represent operations (such as matrix multiplication or activation functions) and edges represent the tensors or data flowing between them. Rather than evaluating each operation immediately as it is encountered in code, graph execution constructs the full computation graph first and then runs it as a whole, enabling the runtime to apply optimizations that would not be possible with operation-by-operation evaluation.
Graph execution became widely known through TensorFlow 1.x, which required users to define a static graph and then execute it within a session. The approach has since evolved significantly. Modern frameworks offer mechanisms like tf.function in TensorFlow 2.x, torch.compile in PyTorch 2.0, and jax.jit in JAX that let developers write code in an imperative style while still benefiting from graph-level optimizations at execution time.
Imagine you want to bake a cake. You could do it step by step: get the flour, then get the eggs, then get the sugar, then mix them, then bake. Each time you finish one step, you figure out what to do next.
Graph execution is like writing out the entire recipe before you start cooking. You draw a picture showing all the steps and how they connect: "flour and eggs go into the mixing bowl, the mixed batter goes into the oven." Because you can see the whole recipe at once, you can figure out smart shortcuts. Maybe you can mix the dry ingredients while someone else cracks the eggs, because those two steps do not depend on each other. Or maybe you notice that two steps could be combined into one.
That is what computers do with graph execution. They look at the whole plan of math operations, find shortcuts, and run things in parallel whenever possible. This makes everything faster, especially when using powerful hardware like GPUs or TPUs.
A computational graph is the underlying data structure that makes graph execution possible. It is a directed acyclic graph where each node corresponds to a mathematical operation and each edge carries a tensor (a multi-dimensional array of numbers) from one operation to another.
| Component | Description |
|---|---|
| Nodes (operations) | Represent functions or transformations applied to data, such as matrix multiplication, convolution, or activation functions like ReLU. |
| Edges (tensors) | Represent the data flowing between operations. Each edge carries a tensor whose shape, dtype, and device placement are known. |
| Input nodes | Entry points of the graph where external data (training examples, features) is fed in. |
| Output nodes | Terminal nodes that produce the final result, such as a loss value or a prediction. |
| Variables | Mutable state nodes that persist across executions and are updated during training. These typically represent weights and biases of a neural network. |
| Constants | Immutable values that are known at graph construction time and do not change during execution. |
Computational graphs are central to automatic differentiation, the mechanism that powers backpropagation in deep learning. During the forward pass, the framework records every operation in the graph. During the backward pass, it traverses the graph in reverse, applying the chain rule at each node to compute gradients with respect to each parameter. Because the graph encodes all the dependencies between operations, the framework can efficiently determine which partial derivatives are needed and in what order they should be computed.
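As a minimal illustration, the following sketch uses PyTorch's autograd (any major framework would work similarly): the forward pass records a graph of operations, and backward() traverses it in reverse to compute gradients.

```python
import torch

# Forward pass: each tensor operation adds a node to the recorded graph.
x = torch.randn(4, 3)                      # input data
y = torch.randn(4, 1)                      # targets
w = torch.randn(3, 1, requires_grad=True)  # parameter (variable node)
b = torch.zeros(1, requires_grad=True)     # parameter (variable node)

pred = x @ w + b                 # matmul and add nodes
loss = ((pred - y) ** 2).mean()  # subtract, square, and mean nodes

# Backward pass: traverse the graph in reverse, applying the chain rule
# at each node; gradients accumulate into w.grad and b.grad.
loss.backward()
print(w.grad.shape, b.grad.shape)  # torch.Size([3, 1]) torch.Size([1])
```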
Deep learning frameworks can be classified by whether they construct computational graphs statically (before execution) or dynamically (during execution). This distinction, sometimes described as "define-and-run" versus "define-by-run," has shaped the design of every major ML framework.
In the static graph model, the entire computation graph is defined before any data flows through it. The framework first builds an internal representation of all operations and their connections, and then executes the graph in a separate step. TensorFlow 1.x is the most prominent example of this approach.
With static graphs, the framework has full visibility into the computation before execution begins, which means it can perform whole-program optimizations such as operator fusion, constant folding, and memory planning. Static graphs can also be serialized and deployed without the original source code, making them well suited for production deployment on mobile devices or embedded systems.
However, static graphs impose constraints on the programming model. Control flow that depends on runtime data (such as variable-length sequences in recurrent neural networks) requires special constructs like tf.cond and tf.while_loop rather than standard Python if and for statements. Debugging is also more difficult because errors are reported at graph execution time, not at the point where the operation was defined in the source code.
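A minimal sketch of these constructs (written against TF 2.x APIs, where they are still required inside graph-compiled code):

```python
import tensorflow as tf

x = tf.constant(3.0)

# A data-dependent branch cannot use a Python `if` in a graph, because the
# tensor's value is unknown at construction time; tf.cond builds both
# branches into the graph and selects one at runtime.
result = tf.cond(x > 0, lambda: tf.square(x), lambda: -x)

# Likewise, a data-dependent loop uses tf.while_loop instead of `while`.
i0 = tf.constant(0)
final_i = tf.while_loop(lambda i: i < 10, lambda i: i + 1, [i0])[0]
```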
In the dynamic graph model, the computation graph is constructed on the fly as operations are executed. Each forward pass through the network builds a new graph, which is then used for backpropagation and discarded. PyTorch popularized this approach (following earlier work by Chainer and DyNet), and it became the default for research workflows because of its flexibility and ease of debugging.
Dynamic graphs allow unrestricted use of Python control flow, making it straightforward to implement models with data-dependent architectures, such as tree-structured networks or models with variable computation depth. Standard Python debugging tools work as expected, and errors are raised immediately at the line that caused them.
The tradeoff is that the framework sees only one operation at a time, limiting the scope of optimization. Techniques like operator fusion and cross-operation memory planning are harder to apply without a global view of the computation.
| Feature | Static graphs | Dynamic graphs |
|---|---|---|
| Construction | Built before execution | Built during execution |
| Terminology | Define-and-run | Define-by-run |
| Control flow | Requires special graph constructs | Uses standard Python control flow |
| Debugging | Errors at graph runtime; harder to trace | Errors at point of definition; standard debuggers work |
| Optimization scope | Whole-program; can fuse, fold, reorder globally | Limited to individual operations or small subgraphs |
| Serialization | Easy; graph can be saved and deployed without source code | Requires additional tooling (e.g., TorchScript, torch.export) |
| Flexibility | Less flexible for dynamic architectures | Naturally supports variable-length inputs, conditional computation |
| Historical examples | TensorFlow 1.x, Theano, Caffe | PyTorch, Chainer, DyNet |
The strict division between static and dynamic graphs has blurred in recent years. TensorFlow 2.x adopted eager execution as the default mode while providing tf.function to convert Python functions into optimized graphs on demand. PyTorch introduced TorchScript in version 1.0 and torch.compile in version 2.0 to capture dynamic Python code into optimizable graph representations. JAX takes a functional approach where jax.jit traces pure functions into a graph intermediate representation (jaxpr) that is then compiled by XLA. The trend across all major frameworks is to let developers write code in a natural, imperative style and then apply graph-level optimizations transparently.
TensorFlow has the most extensive history with graph execution, having been built around it from the very beginning.
In TensorFlow 1.x, graph execution was the only execution mode. Users constructed a tf.Graph by calling TensorFlow API functions (like tf.matmul, tf.add, tf.placeholder), which added operation nodes to the default graph rather than executing them immediately. To actually run the computation, users created a tf.Session and called session.run(), passing in the desired output tensors and a feed dictionary mapping placeholder nodes to input data.
This design was powerful but verbose. A common pattern was the "kitchen sink" approach, where all possible computations were laid out in a single graph and different subsets were executed depending on whether the code was training, evaluating, or running inference. The separation between graph construction and graph execution made code harder to read and debug compared to standard Python.
TensorFlow 2.x made eager execution the default mode, meaning operations execute immediately and return concrete values. To recover the performance benefits of graph execution, TensorFlow 2.x introduced the @tf.function decorator. When a Python function decorated with @tf.function is called, TensorFlow performs a two-stage process:
1. Tracing: TensorFlow runs the Python function once and records every TensorFlow operation it performs into a tf.Graph. Python code runs normally during tracing, but TensorFlow operations are deferred and captured in the graph.
2. Execution: Subsequent calls to the same tf.function with compatible input signatures reuse the previously traced graph without re-running the Python code.

If the function is called with a new input shape or dtype, TensorFlow traces a new graph for that signature. The object returned by tf.function is called a PolymorphicFunction because it can hold multiple ConcreteFunction objects, each corresponding to a different input signature.
AutoGraph is a companion library (enabled by default inside tf.function) that converts Python control flow statements into their TensorFlow graph equivalents. For example, Python if statements that depend on tensor values are converted to tf.cond, and Python while loops become tf.while_loop. This allows developers to write natural Python code while still generating valid graph representations.
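A short sketch of AutoGraph at work; the tensor-dependent `while` and `if` below are rewritten into graph control flow during tracing:

```python
import tensorflow as tf

@tf.function
def halve_until_small(x):
    # Tensor-dependent control flow: AutoGraph converts the `while` into
    # tf.while_loop and the `if` into tf.cond when tracing the graph.
    while tf.reduce_max(x) > 1.0:
        x = x / 2.0
    if tf.reduce_min(x) < 0.0:
        x = tf.abs(x)
    return x

print(halve_until_small(tf.constant([8.0, -4.0])))  # [1.0, 0.5]
```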
The TensorFlow documentation provides a benchmark computing repeated matrix multiplication. Computing the 100th power of a 10x10 matrix 1,000 times took approximately 4.10 seconds in eager mode and 0.80 seconds with tf.function, a roughly 5x speedup [1].
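A sketch in the spirit of that benchmark (exact timings vary by hardware and library version):

```python
import timeit
import tensorflow as tf

x = tf.random.uniform((10, 10))

def power(x, n):
    # n is a Python int, so this loop unrolls into n matmul nodes at trace time.
    result = tf.eye(10, dtype=tf.float32)
    for _ in range(n):
        result = tf.matmul(x, result)
    return result

graph_power = tf.function(power)

print("eager:", timeit.timeit(lambda: power(x, 100), number=1000))
print("graph:", timeit.timeit(lambda: graph_power(x, 100), number=1000))
```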
Before executing a graph, TensorFlow runs it through Grappler, its default graph optimization framework. Grappler applies a series of transformation passes including constant folding, arithmetic simplification, layout optimization, and operation fusion. These passes are applied automatically and can be configured or disabled through the tf.config.optimizer API.
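For example, individual Grappler passes can be toggled through that API; a sketch (the option keys follow the tf.config.optimizer documentation):

```python
import tensorflow as tf

# Inspect the currently configured Grappler options (empty means defaults).
print(tf.config.optimizer.get_experimental_options())

# Selectively disable passes, e.g. to rule out an optimization while debugging.
tf.config.optimizer.set_experimental_options({
    "constant_folding": False,
    "arithmetic_optimization": False,
})
```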
PyTorch was originally designed around dynamic (eager) execution, but the framework has progressively added graph capture and compilation capabilities.
TorchScript, introduced in PyTorch 1.0, provides two mechanisms for converting PyTorch models into a graph-based intermediate representation:
- Tracing (torch.jit.trace): Runs the model with example inputs and records the sequence of operations performed. The resulting trace captures a static graph of the computation. Tracing works well for models that do not contain data-dependent control flow, but it silently produces incorrect results for models with conditional branches or loops that depend on tensor values, because only the code path taken during the tracing run is recorded (see the sketch below).
- Scripting (torch.jit.script): Analyzes the Python source code of the model and compiles it into TorchScript IR, which does support control flow. Scripting handles a subset of Python and may require code modifications to work correctly.

TorchScript models can be serialized and loaded in C++ without a Python runtime, making them suitable for production deployment.
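A minimal sketch of the tracing pitfall and how scripting avoids it:

```python
import torch

class Gate(torch.nn.Module):
    def forward(self, x):
        # Data-dependent branch: tracing records only the path taken.
        if x.sum() > 0:
            return x * 2
        return x - 1

model = Gate()
traced = torch.jit.trace(model, torch.ones(3))  # records the positive branch
scripted = torch.jit.script(model)              # compiles both branches

neg = -torch.ones(3)
print(traced(neg))    # incorrect: still multiplies by 2
print(scripted(neg))  # correct: takes the subtraction branch
```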
PyTorch 2.0 introduced torch.compile, a new compiler-based approach to graph execution that supersedes TorchScript for most use cases. The torch.compile stack consists of three components:
| Component | Role |
|---|---|
| TorchDynamo | A Python bytecode interpreter that intercepts execution, captures sequences of PyTorch operations into FX graphs, and identifies "graph breaks" where unsupported Python features force a switch back to eager execution. |
| AOTAutograd | Generates forward and backward graphs ahead of time, enabling the compiler backend to optimize both training and inference paths. |
| TorchInductor | The default compiler backend that takes FX graphs and generates optimized code. For GPUs, it produces Triton kernels; for CPUs, it generates C++ code with OpenMP parallelism. |
A distinguishing feature of TorchDynamo is its handling of unsupported code. When it encounters Python constructs it cannot convert to graph operations, it inserts a "graph break," compiles the graph captured so far, runs the unsupported code in eager mode, and then resumes graph capture. This allows torch.compile to work on virtually any Python code, even if parts of it cannot be represented as a graph. According to PyTorch documentation, TorchDynamo successfully captures graphs from approximately 99% of code in practice, compared to roughly 50% for earlier graph capture methods [2].
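Basic usage is a one-line change; a sketch:

```python
import torch

def layer(x, w):
    # A chain of operations that TorchInductor can fuse into fewer kernels.
    return torch.relu(x @ w).sin().sum()

compiled = torch.compile(layer)

x, w = torch.randn(64, 64), torch.randn(64, 64)
print(compiled(x, w))  # first call: TorchDynamo captures, backend compiles
print(compiled(x, w))  # subsequent calls reuse the compiled artifact
```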
JAX takes a functional programming approach to graph execution. The jax.jit transformation converts pure Python functions into compiled computations by tracing them into an intermediate representation called jaxpr (JAX expression).
When a function decorated with @jax.jit is called for the first time, JAX executes the function with special tracer objects instead of real numerical values. These tracers record every JAX operation into a jaxpr, which is a typed, functional, first-order intermediate representation in A-normal form (ANF). The jaxpr is then lowered to XLA's High Level Optimizer (HLO) representation and compiled into optimized machine code for the target hardware (CPU, GPU, or TPU).
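The jaxpr for a function can be inspected directly; a small sketch using jax.make_jaxpr:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.tanh(x) ** 2)

# Tracing with abstract values yields the jaxpr rather than a result.
print(jax.make_jaxpr(f)(jnp.ones(3)))
# Prints something like:
# { lambda ; a:f32[3]. let b:f32[3] = tanh a; ... in (d,) }
```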
JAX exposes three explicit compilation stages that users can inspect:
| Stage | Method | Description |
|---|---|---|
| Traced | .trace() | Produces the jaxpr intermediate representation, showing the sequence of primitive operations. |
| Lowered | .lower() | Converts the jaxpr to XLA HLO, the input format for the XLA compiler. |
| Compiled | .compile() | Produces the final compiled executable optimized for the target hardware. |
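A sketch of the lowered and compiled stages (.trace() is exposed in recent JAX releases; .lower() and .compile() are shown here):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.dot(x, x)

x = jnp.ones((4, 4))

lowered = f.lower(x)            # lower the traced jaxpr toward XLA
print(lowered.as_text()[:200])  # inspect the lowered IR as text

compiled = lowered.compile()    # final executable for the local backend
print(compiled(x))              # call the compiled artifact directly
```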
Because JAX functions are expected to be pure (free of side effects), the compiler has strong guarantees about what optimizations are safe to apply. This functional discipline makes it straightforward to compose jax.jit with other JAX transformations such as jax.grad (automatic differentiation), jax.vmap (automatic vectorization), and jax.pmap (parallel execution across devices).
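A sketch of this composability, combining jit, grad, and vmap:

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

# grad builds the gradient function, vmap maps it over a leading batch
# axis (w is shared, hence in_axes=None), and jit compiles the pipeline.
grad_fn = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.ones(3)
xs = jnp.ones((8, 5, 3))   # 8 independent batches of 5 examples each
ys = jnp.ones((8, 5))
print(grad_fn(w, xs, ys).shape)  # (8, 3): one gradient per batch
```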
XLA (Accelerated Linear Algebra) is an open-source compiler for machine learning that sits at the center of graph execution for both TensorFlow and JAX (and is increasingly used by PyTorch via PyTorch/XLA). XLA takes a high-level graph representation and compiles it into optimized machine code for specific hardware targets.
The XLA compilation process converts the input graph into XLA's HLO representation, applies a sequence of target-independent optimization passes (such as fusion and constant folding), then applies target-specific passes, and finally generates machine code for the target device.
Operation fusion is XLA's most impactful optimization. Memory bandwidth is often the scarcest resource on hardware accelerators, so reducing memory transfers by combining multiple operations into a single kernel launch produces large performance gains. For example, an element-wise addition followed by a multiplication and a reduction can be fused into a single GPU kernel instead of three separate kernel launches with intermediate memory reads and writes.
XLA also performs constant folding (computing results of constant expressions at compile time), buffer reuse analysis (sharing memory between tensors whose lifetimes do not overlap), and target-specific code generation (using architecture-specific instructions and memory layouts).
In TensorFlow, XLA can be enabled on a per-function basis using tf.function(jit_compile=True). In JAX, XLA is the default compilation backend for all jax.jit-compiled functions.
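A sketch mirroring the fusion example above (add, multiply, reduce), compiled with XLA in TensorFlow:

```python
import tensorflow as tf

@tf.function(jit_compile=True)
def fused(x, y):
    # Element-wise add and multiply followed by a reduction: XLA can fuse
    # these into a single kernel instead of three separate launches.
    return tf.reduce_sum((x + y) * y)

x = tf.random.uniform((1024, 1024))
y = tf.random.uniform((1024, 1024))
print(fused(x, y))
```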
Graph-level optimizations are applied by framework runtimes and compilers to improve the performance of computational graphs before or during execution. These optimizations fall into several categories.
Operator fusion (also called kernel fusion) combines multiple adjacent operations in the graph into a single fused operation. This reduces the number of kernel launches on GPUs and eliminates intermediate memory allocations. For example, a batch normalization layer followed by a ReLU activation can be fused into a single kernel that reads the input once, computes both operations, and writes the output once.
Constant folding identifies subgraphs where all inputs are known at compile time and evaluates them during compilation rather than at runtime. The result replaces the original subgraph with a single constant node. This is particularly effective for operations involving model hyperparameters or fixed preprocessing steps.
Dead code elimination removes nodes in the graph whose outputs are never consumed by any downstream operation. The process works similarly to garbage collection: starting from the output nodes, the optimizer traces backward through the graph, marks every reachable node as "live," and removes any node not marked.
Common subexpression elimination (CSE) identifies cases where the same operation with the same inputs appears multiple times in the graph. Instead of computing the result multiple times, the optimizer computes it once and reuses the result wherever it is needed.
Memory optimization techniques include liveness analysis (determining which tensors are needed at each point in execution and freeing those that are no longer required), in-place operations (modifying tensors in place rather than creating copies when safe to do so), and memory layout transformations (rearranging data in memory to match hardware preferences, such as converting from NCHW to NCHWc format on CPUs).
| Technique | What it does | Benefit |
|---|---|---|
| Operator fusion | Combines adjacent operations into a single kernel | Reduces kernel launch overhead and memory transfers |
| Constant folding | Evaluates constant subgraphs at compile time | Eliminates unnecessary runtime computation |
| Dead code elimination | Removes unreachable or unused operations | Reduces graph size and execution time |
| Common subexpression elimination | Reuses results of duplicate computations | Avoids redundant work |
| Layout optimization | Transforms data layout to match hardware preferences | Improves cache utilization and memory access patterns |
| Memory planning | Allocates and reuses buffers based on tensor lifetimes | Reduces peak memory consumption |
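As a concrete (if framework-specific) illustration, PyTorch's FX graphs expose a dead code elimination pass directly; a sketch:

```python
import torch
import torch.fx

class Net(torch.nn.Module):
    def forward(self, x):
        unused = x.exp()        # output never consumed: dead code
        return x.relu() + 1

gm = torch.fx.symbolic_trace(Net())
print(len(gm.graph.nodes))      # node count includes the dead exp node

gm.graph.eliminate_dead_code()  # remove nodes whose outputs have no users
gm.recompile()
print(len(gm.graph.nodes))      # the exp node is gone
```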
ONNX (Open Neural Network Exchange) is an open standard that defines a common format for representing machine learning models as computational graphs. ONNX enables models trained in one framework to be deployed using another framework's runtime, breaking down the barriers between different ML ecosystems.
An ONNX model consists of a computational graph where nodes represent operators (convolution, pooling, matrix multiplication, etc.) and edges represent the tensors flowing between them. Each node includes attributes that define its behavior (such as kernel size for a convolution or axis for a reduction). The ONNX specification defines a standard set of over 150 operators, and frameworks implement converters that translate their native graph representations into ONNX format.
PyTorch provides torch.onnx.export for converting models to ONNX, while TensorFlow models can be converted using the tf2onnx tool. Scikit-learn models (which do not use computational graphs internally) can be converted using sklearn-onnx.
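A sketch of a PyTorch-to-ONNX export (the output path and tensor names are illustrative):

```python
import torch

model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.ReLU(),
)
example_input = torch.randn(1, 10)

# Export runs the model on the example input and writes its graph as ONNX.
torch.onnx.export(
    model,
    example_input,
    "model.onnx",                          # hypothetical output path
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch"}},  # allow a variable batch size
)
```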
ONNX Runtime is Microsoft's high-performance inference engine for ONNX models. It applies its own graph optimizations organized into three levels:
| Level | Scope | Example optimizations |
|---|---|---|
| Basic | Applied before graph partitioning; provider-independent | Constant folding, redundant node elimination, Conv+BatchNorm fusion, Conv+Add fusion |
| Extended | Applied after partitioning; provider-specific | GEMM activation fusion, attention fusion, layer normalization fusion, BERT embedding layer fusion |
| Layout | CPU-specific transformations | NCHW to NCHWc format conversion for improved CPU performance |
ONNX Runtime uses a system of execution providers to dispatch graph nodes to specialized hardware backends. When a model is loaded, the runtime queries each available execution provider (such as CUDA, TensorRT, OpenVINO, or DirectML) to determine which nodes it can handle. The graph is then partitioned into subgraphs, with each subgraph assigned to the most appropriate execution provider. Nodes not supported by any accelerator fall back to the default CPU execution provider [3].
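A sketch of loading a model with an explicit optimization level and provider list (the model file and input name are illustrative):

```python
import numpy as np
import onnxruntime as ort

opts = ort.SessionOptions()
# Request the "extended" optimization level described above.
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED

# Providers are tried in order; nodes no provider supports fall back to CPU.
session = ort.InferenceSession(
    "model.onnx",  # hypothetical model file
    sess_options=opts,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)

outputs = session.run(
    None,  # None requests all model outputs
    {"input": np.random.randn(1, 10).astype(np.float32)},
)
```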
ONNX primarily targets static computational graphs. Models with highly dynamic structures, such as those with data-dependent control flow or variable computation paths, may not convert cleanly to ONNX format. The conversion process can also introduce subtle numerical differences or lose framework-specific optimizations. Despite these limitations, ONNX has become a widely adopted standard for model interoperability and deployment.
Because the framework can inspect the entire computation before executing it, graph execution enables optimizations that are impossible in pure eager mode. The runtime can fuse operations, eliminate redundant computations, optimize memory allocation, and schedule work across multiple devices. For large models running on accelerators, these optimizations can produce significant speedups.
Graph execution enables better utilization of parallel hardware. By analyzing the dependency structure of the graph, the runtime can identify independent operations that can execute simultaneously on different cores, stream processors, or devices. This is particularly valuable on GPUs, which have thousands of cores that benefit from large-scale parallelism, and on TPUs, which are designed specifically for graph-based execution of tensor operations.
A serialized computational graph can be loaded and executed without the original Python source code. This makes graph execution well suited for deploying models to production environments, mobile devices, embedded systems, and web browsers. TensorFlow's SavedModel format, PyTorch's TorchScript and ExportedProgram formats, and ONNX all support this pattern.
With a complete view of the computation graph, the runtime can plan memory allocation in advance. It can determine which tensors are needed at each stage of execution, share memory buffers between tensors whose lifetimes do not overlap, and schedule operations to minimize peak memory usage. This is especially important when training large models where GPU memory is a limiting factor.
Graph execution introduces a layer of abstraction between the code the developer writes and the operations that actually execute. Errors may surface during graph execution rather than at the point where the problematic operation was defined in the source code. While tools like TensorFlow's tf.debugging module and PyTorch's graph break reports help, debugging graph-compiled code remains more difficult than debugging eager code.
Graph tracing captures the operations performed during a single execution of the function, which means that Python side effects (print statements, mutations to Python data structures, random number generation using Python's random module) behave differently in graph mode than in eager mode. Developers must be aware that traced functions may not re-execute Python code on subsequent calls, only the recorded graph operations.
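The classic demonstration uses a Python print inside a traced function; a sketch:

```python
import tensorflow as tf

@tf.function
def f(x):
    print("python print: runs only while tracing")  # Python side effect
    tf.print("tf.print: runs on every call")        # graph operation
    return x + 1

f(tf.constant(1))    # first call traces: both lines print
f(tf.constant(2))    # graph reused: only the tf.print line appears
f(tf.constant(3.0))  # new dtype forces a retrace: both lines print again
```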
Models with highly dynamic computation patterns, such as those where the graph structure changes based on the input data at every step, may not benefit from graph execution. The overhead of repeatedly tracing and compiling new graphs can outweigh the performance gains from optimization. In such cases, eager execution or hybrid approaches (where only the stable parts of the computation are compiled) may be more appropriate.
The first call to a graph-compiled function includes the cost of tracing and compiling the graph, which can be significant for complex models. For models that are called only a few times, this overhead may not be recouped. Caching mechanisms (such as TensorFlow's ConcreteFunction reuse or JAX's compilation cache) mitigate this cost for functions called repeatedly with the same input shapes.
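A sketch measuring this one-time cost with JAX (timings vary by hardware):

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.tanh(x @ x).sum()

x = jnp.ones((512, 512))

start = time.perf_counter()
f(x).block_until_ready()   # first call: trace + XLA compile + execute
print("first call:", time.perf_counter() - start)

start = time.perf_counter()
f(x).block_until_ready()   # cached executable: execute only
print("second call:", time.perf_counter() - start)
```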
| Feature | TensorFlow 2.x | PyTorch 2.x | JAX |
|---|---|---|---|
| Graph capture mechanism | @tf.function with AutoGraph | torch.compile with TorchDynamo | @jax.jit with tracing |
| Intermediate representation | tf.Graph / FunctionDef | FX Graph (ATen IR) | jaxpr |
| Compiler backend | Grappler, XLA (optional) | TorchInductor (default), XLA (via PyTorch/XLA) | XLA (default) |
| Control flow handling | AutoGraph converts Python control flow to graph ops | Graph breaks; falls back to eager for unsupported constructs | Requires jax.lax.cond, jax.lax.scan for traced control flow |
| Dynamic shapes | Retraces for new shapes; input_signature can fix shapes | Supports dynamic shapes with guards | Retraces for new shapes; static_argnums can fix arguments |
| Serialization format | SavedModel, TFLite | ExportedProgram, TorchScript | StableHLO (experimental) |
| Default execution mode | Eager (graph via @tf.function) | Eager (graph via torch.compile) | Eager (graph via @jax.jit) |
| Legacy graph mode | TF 1.x Session-based execution | TorchScript (torch.jit.trace, torch.jit.script) | N/A (JAX has always used JIT tracing) |
Visualizing computational graphs helps developers understand model architecture, identify bottlenecks, and verify that optimizations are applied correctly.
- TensorBoard: TensorFlow's visualization dashboard renders tf.Graph objects, showing operations, data flow, and device placements.
- print(jaxpr): JAX allows direct inspection of the jaxpr intermediate representation by calling .trace() on a jit-compiled function and printing the result.
- torch.fx utilities: PyTorch's FX graphs provide graph.print_tabular() and graph.python_code() methods for inspecting captured graphs in torch.compile.
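A sketch of exporting a traced graph for TensorBoard (the log directory is illustrative):

```python
import tensorflow as tf

@tf.function
def model_step(x):
    return tf.reduce_sum(tf.square(x))

writer = tf.summary.create_file_writer("logs/graph")  # hypothetical log dir
tf.summary.trace_on(graph=True)  # start recording the traced graph
model_step(tf.ones((4, 4)))      # first call triggers tracing
with writer.as_default():
    tf.summary.trace_export(name="model_step", step=0)
# View with: tensorboard --logdir logs/graph
```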