Computational graph
Last reviewed
May 1, 2026
Sources
16 citations
Review status
Source-backed
Revision
v1 · 3,954 words
A computational graph is a directed acyclic graph (DAG) representation of a numerical computation, where nodes represent operations (or variables) and edges represent the data, typically tensors, that flow between them. It is the central abstraction of modern deep-learning frameworks: every neural network forward pass is built as a computational graph, and the backward pass that computes gradients is performed by traversing this graph in reverse to apply the chain rule of calculus, a procedure known as reverse-mode automatic differentiation.
The graph view of computation predates deep learning by several decades, with roots in dataflow programming and in classical algorithmic differentiation literature (Griewank, Wengert). It became the dominant abstraction in machine learning around 2010 with the release of Theano, and was carried forward into TensorFlow, PyTorch, JAX, MXNet, and most other modern frameworks. Understanding computational graphs is foundational to understanding how backpropagation works, why GPU acceleration is possible at the operation level, and how compilers such as XLA, TorchInductor, and TVM optimize neural-network computation.
Formally, a computational graph for a function f is a DAG G = (V, E) in which:
A simple example is the expression z = (x + y) * (x - y). The graph has two input nodes (x and y), two intermediate operation nodes (add and subtract), and one output node (multiply). Evaluating in topological order computes the add and subtract first, then feeds them into the multiply. To differentiate, the framework records this graph and walks it backward to accumulate partial derivatives.
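The forward half of this example can be sketched in a few lines of plain Python. The dictionary layout, the node names (`add`, `sub`, `mul`), and the `evaluate` helper are illustrative assumptions, not any framework's API; the standard-library `graphlib` module supplies the topological ordering.

```python
from graphlib import TopologicalSorter

# Each node maps to (operation, input node names); input nodes have no operation.
graph = {
    "x": (None, []),
    "y": (None, []),
    "add": (lambda a, b: a + b, ["x", "y"]),
    "sub": (lambda a, b: a - b, ["x", "y"]),
    "mul": (lambda a, b: a * b, ["add", "sub"]),
}

def evaluate(graph, inputs):
    """Evaluate every node in topological order and return all node values."""
    # TopologicalSorter takes {node: predecessors} and yields predecessors first.
    deps_only = {name: deps for name, (_, deps) in graph.items()}
    values = dict(inputs)
    for name in TopologicalSorter(deps_only).static_order():
        op, deps = graph[name]
        if op is not None:
            values[name] = op(*(values[d] for d in deps))
    return values

values = evaluate(graph, {"x": 3.0, "y": 2.0})
print(values["add"], values["sub"], values["mul"])  # 5.0 1.0 5.0
```

Because the graph is plain data, the same structure can be evaluated again with different inputs or walked in the opposite direction to propagate derivatives.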
In deep-learning frameworks, the values flowing along edges are tensors: multi-dimensional arrays with a defined shape and dtype. The same DAG structure scales from a two-variable arithmetic example to transformer training graphs containing tens of thousands of nodes.
Computational graphs are the data structure that makes modern deep learning practical. Three properties account for their dominance:
The abstraction is general enough to express not only feedforward networks but also recurrent networks, attention, convolutions, control flow, probabilistic models, and continuous-depth networks (Neural ODEs). Frameworks such as PyMC and Pyro reuse the same machinery for probabilistic programming.
Deep-learning frameworks fall into two broad camps depending on when the graph is constructed.
| Paradigm | Also known as | Construction time | Example frameworks | Strengths | Weaknesses |
|---|---|---|---|---|---|
| Static graph | Define-and-run, graph execution | Graph is built once, then executed many times with different data | Theano, TensorFlow 1.x, Caffe, MXNet (Symbolic), JAX (under jit) | Whole-graph optimization, easier deployment, ahead-of-time compilation, predictable memory, easier serving | Awkward debugging, less Pythonic control flow, longer dev iteration |
| Dynamic graph | Define-by-run, eager execution | Graph is built implicitly as Python code runs | Chainer, PyTorch, DyNet, TensorFlow 2.x default, Autograd | Pythonic, easy to debug with standard tools, natural support for variable-shape inputs and control flow | Harder to apply whole-program optimizations, traditionally slower for repeated identical computations |
The static approach was dominant from roughly 2010 to 2017 because it made it easier to extract performance from immature GPU stacks. The release of Chainer in 2015 introduced the define-by-run idea, which PyTorch adopted in 2017 and made mainstream. By 2019, TensorFlow had switched its 2.x default to eager mode while keeping a path to static graphs through tf.function. Today, most frameworks try to offer both: eager-by-default for development, with a way to compile hot regions into a static graph for production performance.
A computational graph is built and used in two phases.
During the forward pass, the framework computes the value of each node in topological order. In an eager framework like PyTorch, every Python expression that touches a Tensor triggers an actual kernel call and adds a node to the implicit graph. In a static framework like TensorFlow 1.x or under JAX's jit, the same Python code instead constructs a symbolic graph that will be executed later by a runtime or compiled to XLA.
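The define-and-run style can be illustrated with a toy symbolic graph in plain Python. The `Node`, `placeholder`, and `run` names are hypothetical and only loosely modeled on the placeholder-and-session pattern; this is a sketch of the idea, not how any real runtime is implemented.

```python
class Node:
    """A symbolic node: constructing it performs no arithmetic."""
    def __init__(self, op, inputs=(), name=None):
        self.op, self.inputs, self.name = op, inputs, name

    def __add__(self, other):
        return Node("add", (self, other))

    def __mul__(self, other):
        return Node("mul", (self, other))

def placeholder(name):
    """Symbolic input slot that is filled in at run time."""
    return Node("placeholder", name=name)

def run(node, feed):
    """Execute the graph for one set of input values."""
    if node.op == "placeholder":
        return feed[node.name]
    a, b = (run(i, feed) for i in node.inputs)
    return a + b if node.op == "add" else a * b

# Phase 1: build the graph once. No numbers are involved yet.
x, y = placeholder("x"), placeholder("y")
out = (x + y) * y

# Phase 2: execute the same graph repeatedly with different data.
first = run(out, {"x": 1.0, "y": 2.0})   # (1 + 2) * 2
second = run(out, {"x": 3.0, "y": 4.0})  # (3 + 4) * 4
print(first, second)  # 6.0 28.0
```

In an eager framework the two phases collapse into one: writing `(x + y) * y` with concrete tensors would compute the result immediately while recording the nodes as a side effect.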
For each node, the framework typically records:
Given a scalar loss L, the backward pass computes dL/dx for every leaf tensor x that requires gradients. It does this by initializing dL/dL = 1 at the output, then walking the graph in reverse topological order. At each node, given the upstream gradient (the partial of L with respect to the node's output), it applies the local Jacobian to produce the gradient with respect to each of the node's inputs. These gradients flow back along the edges and accumulate at parameter nodes through the chain rule.
This is exactly reverse-mode automatic differentiation. In neural networks, where the input dimension (number of parameters) is far larger than the output dimension (a single scalar loss), reverse mode is dramatically more efficient than forward mode, which is why it became the standard.
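A minimal scalar sketch of this backward walk follows. The `Value` class and `backward` function are illustrative assumptions in the style of small teaching autodiff engines; real frameworks operate on tensors and store richer metadata per node.

```python
class Value:
    """Scalar node that remembers its parents and the local partial derivatives."""
    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self.parents = parents  # tuple of (parent Value, d(self)/d(parent))

    def __add__(self, other):
        return Value(self.data + other.data, ((self, 1.0), (other, 1.0)))

    def __sub__(self, other):
        return Value(self.data - other.data, ((self, 1.0), (other, -1.0)))

    def __mul__(self, other):
        return Value(self.data * other.data,
                     ((self, other.data), (other, self.data)))

def backward(loss):
    """Accumulate dL/dnode into node.grad for every node reachable from loss."""
    order, seen = [], set()

    def visit(node):  # depth-first post-order yields a topological order
        if id(node) not in seen:
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)

    visit(loss)
    loss.grad = 1.0  # dL/dL = 1
    for node in reversed(order):
        for parent, local in node.parents:
            # Upstream gradient times local derivative, summed over all paths.
            parent.grad += node.grad * local

x, y = Value(3.0), Value(2.0)
z = (x + y) * (x - y)
backward(z)
print(x.grad, y.grad)  # 6.0 -4.0
```

The result matches the analytic derivatives of z = x^2 - y^2, namely 2x and -2y. Both x and y receive contributions from two paths, which is why gradients are accumulated with `+=` rather than assigned.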
The choice of differentiation mode determines how the graph is traversed and how gradients are assembled. Baydin, Pearlmutter, Radul, and Siskind (2018) provide the canonical survey.
| Mode | What it computes | Cost (relative to forward) | When efficient | Used by |
|---|---|---|---|---|
| Forward-mode autodiff | Jacobian-vector products (JVP), one column of the Jacobian per pass | About 2 to 3 times the forward cost per JVP | Number of inputs <= number of outputs (tall Jacobian) | JAX (jax.jvp), PyTorch (torch.autograd.functional.jvp) |
| Reverse-mode autodiff (backpropagation) | Vector-Jacobian products (VJP), one row of the Jacobian per pass | About 2 to 4 times the forward cost per VJP, but stores activations | Number of inputs >> number of outputs (wide Jacobian, scalar loss) | All major DL frameworks |
| Symbolic differentiation | Closed-form derivative expression | Variable; can blow up due to expression swell | Small symbolic problems | Mathematica, SymPy, original Theano (partly) |
| Numerical differentiation | Finite differences | One forward eval per partial derivative | Tiny problems, gradient checking only | Often used to validate autodiff implementations |
For a typical neural network with millions of parameters and a single scalar loss, reverse-mode autodiff requires only a single backward pass to compute every gradient, which is why backpropagation has been the workhorse of deep-learning training since the 1980s. Forward mode is occasionally useful for things like Hessian-vector products (when combined with reverse mode) and for differentiating through tall functions where the output is higher-dimensional than the input.
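The cost asymmetry can be seen in a small forward-mode sketch built on dual numbers, checked against central finite differences. The `Dual` class and the `jvp` helper are hypothetical names for illustration and are unrelated to the `jax.jvp` signature.

```python
class Dual:
    """Pair (value, tangent): the tangent carries a directional derivative."""
    def __init__(self, val, tan=0.0):
        self.val, self.tan = val, tan

    def __add__(self, other):
        return Dual(self.val + other.val, self.tan + other.tan)

    def __sub__(self, other):
        return Dual(self.val - other.val, self.tan - other.tan)

    def __mul__(self, other):  # product rule
        return Dual(self.val * other.val,
                    self.tan * other.val + self.val * other.tan)

def f(x, y):
    return (x + y) * (x - y)

def jvp(x, y, vx, vy):
    """One forward pass: derivative of f at (x, y) along direction (vx, vy)."""
    return f(Dual(x, vx), Dual(y, vy)).tan

# Two inputs require two forward passes, one per basis direction.
df_dx = jvp(3.0, 2.0, 1.0, 0.0)
df_dy = jvp(3.0, 2.0, 0.0, 1.0)

# Central finite differences as an independent numerical check.
eps = 1e-6
fd_dx = (f(3.0 + eps, 2.0) - f(3.0 - eps, 2.0)) / (2 * eps)
fd_dy = (f(3.0, 2.0 + eps) - f(3.0, 2.0 - eps)) / (2 * eps)
print(df_dx, df_dy)  # 6.0 -4.0
```

With n inputs, this approach needs n passes to assemble the full gradient, whereas reverse mode needs one backward pass for a scalar output. The finite-difference values agree only up to rounding error, which is why that method is kept for validation rather than training.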
Deep-learning frameworks share a small vocabulary of node types.
| Element | Role | Notes |
|---|---|---|
| Tensor | Multi-dimensional array flowing along an edge | Has shape, dtype, and device; the basic data primitive |
| Operation (op) | Function applied at a node | Examples: matmul, add, conv2d, softmax, layernorm |
| Variable / Parameter | Tensor that requires gradients | Weights, biases; participates in optimizer updates |
| Constant | Tensor with no gradient | Input data in many setups, fixed lookup tables |
| Placeholder | Symbolic input slot (TF 1.x) | Filled at session run time; obsolete in TF 2.x eager mode |
| Saved tensor | Intermediate value cached for backward | Activations stored during forward pass |
| Gradient function | Backward op for each forward op | Encodes the local Jacobian of the operation |
Frameworks differ in how they expose these. In PyTorch, every non-leaf tensor that requires gradients carries a grad_fn attribute pointing to the autograd Function that computes its backward step. TensorFlow 1.x had explicit tf.placeholder and tf.Variable types. JAX traces pure functions and produces a jaxpr (JAX program representation) that is then handed to XLA.
The practical difference between the two paradigms shows up in the developer workflow.
| Aspect | Static graph workflow | Dynamic graph workflow |
|---|---|---|
| When the graph is built | Once, before any data flows | Implicitly, as each op runs |
| Debugging | Hard; errors surface at session run, not at the offending Python line | Easy; standard Python debuggers work, errors point at the line |
| Variable shapes | Must be known (or symbolic) at graph build time | Can vary per call |
| Control flow | Encoded with framework-specific ops (tf.cond, tf.while_loop) | Native Python if, for, while |
| Optimization opportunities | High; whole-graph view enables operator fusion, layout transforms, scheduling | Limited unless a separate trace/compile step is added |
| Deployment | Easy to serialize and ship the graph | Requires a separate export step (TorchScript, ONNX, torch.compile) |
| Performance per call | Often higher after compilation | Often lower per op, but no compile latency |
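The control-flow row can be made concrete with a toy recorder in which the recorded graph depends on the data. The `tape`, `mul`, `add`, `model`, and `trace` names are illustrative assumptions; a real eager framework records autograd nodes rather than strings.

```python
tape = []  # ops recorded in the order they actually ran

def mul(a, b):
    tape.append("mul")
    return a * b

def add(a, b):
    tape.append("add")
    return a + b

def model(x):
    # Ordinary Python control flow: the loop count depends on the input value.
    while x < 100.0:
        x = mul(x, 2.0)
    return add(x, 1.0)

def trace(x):
    """Run the model eagerly and return (result, ops recorded for this call)."""
    tape.clear()
    result = model(x)
    return result, list(tape)

print(trace(30.0))  # (121.0, ['mul', 'mul', 'add'])
print(trace(5.0))   # five doublings before the add; result is 161.0
```

Each call produces a graph of a different length, which a define-by-run system handles naturally. A static-graph system would need a framework-level loop construct to express the same computation as a single graph.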
Most modern systems blur the line. PyTorch 2.0 introduced torch.compile, which captures eager Python code into a graph and compiles it via TorchInductor. JAX's jax.jit decorator traces a pure Python function once and compiles the resulting jaxpr to XLA, giving define-by-run ergonomics with static-graph performance. TensorFlow's tf.function does the equivalent for TF 2.x.
| Framework | First release | Default mode | Graph approach | Status |
|---|---|---|---|---|
| Theano | 2010 | Static | Symbolic computational graph in Python with C++/CUDA codegen | Discontinued in 2017; spiritual successor in PyMC's Aesara/PyTensor |
| Torch (Lua) | 2002 / DL focus 2011 | Eager | Imperative tensor library with hand-coded backprop modules | Largely replaced by PyTorch |
| Caffe | 2013 | Static | Layer-based prototxt graph | Caffe2 merged into PyTorch in 2018 |
| TensorFlow 1.x | 2015 | Static | tf.placeholder + tf.Session.run; explicit graphs | Superseded by 2.x |
| Chainer | 2015 | Dynamic | Pioneered define-by-run; influenced PyTorch | Maintenance mode since 2019 |
| MXNet | 2015 | Both | Symbolic API plus Gluon imperative API; hybridization for production | Apache repo archived November 2023 |
| PyTorch | 2017 | Dynamic | Autograd tape records ops as they run | Active; default research framework |
| DyNet | 2017 | Dynamic | Dynamic graph framework predating PyTorch | Maintained but niche |
| TensorFlow 2.x | 2019 | Eager | Eager by default, tf.function traces to static graphs | Active |
| JAX | 2018 | Functional | grad and jit transform pure Python functions; produces jaxpr for XLA | Active; growing in research |
| PyTorch 2.0 | March 2023 | Dynamic + compile | TorchDynamo captures graphs from eager Python; AOTAutograd handles backward; TorchInductor lowers to Triton/C++ | Active |
| Flux.jl | 2017 | Dynamic | Differentiable programming in Julia via Zygote | Active |
| Equinox / Diffrax | 2021+ | Functional | Pytree-based modules on top of JAX | Active |
The progression from Theano to PyTorch 2.0 traces a slow convergence: static-graph systems are adding eager front-ends for usability, and dynamic-graph systems are adding compilation backends for performance. The two communities are arriving at the same answer from opposite directions.
Once a computational graph exists, compilers can apply transformations that would be impossible at the Python interpreter level.
| Compiler / format | Used by | What it does |
|---|---|---|
| XLA (Accelerated Linear Algebra) | TensorFlow, JAX, PyTorch/XLA | Operator fusion, layout transforms, sharding for TPUs and GPUs |
| TorchScript | PyTorch 1.x | JIT-compiled IR for Python-free deployment; being phased out in favor of torch.compile and ONNX exporter |
| TorchDynamo | PyTorch 2.0+ | Captures Python frames into FX graphs using Python's frame evaluation API |
| AOTAutograd | PyTorch 2.0+ | Traces forward and backward passes ahead of time so both can be compiled |
| TorchInductor | PyTorch 2.0+ | Lowers graphs to Triton kernels on GPU and C++/OpenMP on CPU |
| TVM | Many | Open-source ML compiler stack with two-level Relay/TIR IR |
| MLIR | TensorFlow, others | LLVM-style multi-level IR infrastructure for ML compilers |
| ONNX | Cross-framework | Interchange format for static graphs; widely used for inference |
| TensorRT | NVIDIA inference | Imports ONNX or PyTorch graphs and produces optimized engines for NVIDIA GPUs |
The most common transformation is operator fusion: collapsing several small ops (such as add, multiply, and relu) into a single kernel. This eliminates intermediate writes to global memory and reduces kernel-launch overhead. XLA, TVM, and Inductor all do this aggressively. Fusion is what allowed projects like Triton and FlashAttention to move attention from a memory-bound operation to a compute-bound one in modern LLM serving.
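The effect of fusion can be sketched in plain Python, with list comprehensions standing in for GPU kernels. The function names and the multiply-add-relu chain are illustrative assumptions; a real compiler generates the fused kernel automatically from the graph.

```python
def unfused(xs, scale, bias):
    """Three separate 'kernels', each writing a full intermediate list."""
    scaled = [x * scale for x in xs]        # intermediate 1
    shifted = [s + bias for s in scaled]    # intermediate 2
    return [max(t, 0.0) for t in shifted]   # relu output

def fused(xs, scale, bias):
    """One 'kernel': multiply, add, and relu applied per element in one pass."""
    return [max(x * scale + bias, 0.0) for x in xs]

data = [-2.0, -0.5, 0.0, 1.5]
print(unfused(data, 2.0, 1.0))  # [0.0, 0.0, 1.0, 4.0]
print(fused(data, 2.0, 1.0))    # [0.0, 0.0, 1.0, 4.0]
```

Both versions compute the same values, but the fused one never materializes `scaled` or `shifted`. On a GPU those intermediates would each be a round trip through global memory plus an extra kernel launch.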
The core algorithm of every deep-learning trainer is reverse-mode autodiff over a computational graph. Step by step:
The efficiency of this scheme is what makes training models with hundreds of billions of parameters feasible. A single backward pass costs a small constant multiple of a forward pass (roughly 2 to 4 times) and produces gradients with respect to every parameter at once, regardless of how many parameters there are.
Real models need a handful of operations that go beyond plain forward and backward.
PyTorch's tensor.detach(), TensorFlow's tf.stop_gradient, and jax.lax.stop_gradient all prevent gradient flow through a sub-expression. This is essential for techniques like target networks in reinforcement learning, REINFORCE estimators, and certain generative training setups.
torch.autograd.grad with create_graph=True and JAX's nested grad calls produce gradients of gradients (Hessians, third derivatives, and so on). These are used for meta-learning (MAML), influence functions, and physics-informed models.
Custom gradient rules can be registered per operation (torch.autograd.Function, tf.custom_gradient, jax.custom_vjp).
The activations stored during the forward pass for use during the backward pass are the dominant memory cost in modern training. Memory grows roughly linearly with the depth of the network and with the batch size. For a transformer, attention activations scale quadratically with sequence length unless techniques like FlashAttention recompute them on the fly.
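The trade between stored activations and recomputation can be sketched on a chain of scalar layers. The technique shown is usually called gradient checkpointing, a name the text above does not use; the `layer` function, the helper names, and the `every` parameter are assumptions chosen for illustration.

```python
import math

def layer(h):
    return math.sin(h)

def layer_grad(h):
    return math.cos(h)  # the local derivative needs the layer's input h

def grad_store_all(x, depth):
    """Keep every layer input; returns (d out / d x, number of saved values)."""
    saved = []
    h = x
    for _ in range(depth):
        saved.append(h)
        h = layer(h)
    g = 1.0
    for h_in in reversed(saved):
        g *= layer_grad(h_in)
    return g, len(saved)

def grad_checkpointed(x, depth, every):
    """Keep only every `every`-th input and recompute the rest during backward."""
    checkpoints = []
    h = x
    for i in range(depth):
        if i % every == 0:
            checkpoints.append((i, h))
        h = layer(h)
    g = 1.0
    end = depth
    for start, h_start in reversed(checkpoints):
        # Recompute the inputs of layers start..end-1 from the checkpoint.
        segment = [h_start]
        for _ in range(start + 1, end):
            segment.append(layer(segment[-1]))
        for h_in in reversed(segment):
            g *= layer_grad(h_in)
        end = start
    return g, len(checkpoints)

full, n_full = grad_store_all(0.7, depth=12)
ckpt, n_ckpt = grad_checkpointed(0.7, depth=12, every=4)
print(n_full, n_ckpt, math.isclose(full, ckpt))  # 12 3 True
```

The checkpointed version saves 3 values instead of 12 during the forward pass, at the cost of re-running each segment once during the backward pass. It also holds one recomputed segment at a time, so peak memory is the checkpoints plus one segment rather than the checkpoints alone.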
Three common levers reduce memory use:
On the throughput side, fused kernels reduce kernel-launch overhead and global-memory traffic. FlashAttention, written in CUDA and later in Triton, restructures attention as a single fused kernel that never materializes the full attention matrix. Similar fusion at the graph level is what torch.compile and jax.jit automate.
Computational graphs sit at the intersection of several older ideas in computer science and applied mathematics.
| Concept | Connection |
|---|---|
| Dataflow programming | Computational graphs are a special case; nodes fire when inputs are available |
| Symbolic differentiation | Computer algebra systems such as Mathematica and Maple build expression trees and apply differentiation rules; the approach suffers from expression swell |
| Algorithmic differentiation (Griewank) | The classical autodiff literature predates ML by decades; reverse mode was known as adjoint mode |
| Differentiable programming | Yann LeCun's term for treating any differentiable program as a model; computational graphs are the substrate |
| Probabilistic programming | PyMC, Pyro, NumPyro, and Stan all rely on autodiff over computational graphs to sample posteriors |
| Neural ODEs (Chen et al., 2018) | Continuous-depth networks; gradients computed via the adjoint method, which is reverse-mode autodiff in continuous time |
| Differentiable simulators | Physics simulators built on top of JAX/PyTorch differentiate through entire simulations |
Seeing the graph helps with debugging and architecture design. Common tools:
| Tool | Framework | Output |
|---|---|---|
| TensorBoard | TensorFlow, PyTorch | Interactive web UI for graph and training metrics |
| torchview / torchsummary | PyTorch | Static diagrams of model architecture |
| torch.fx GraphModule | PyTorch | Programmatic IR for inspection and transformation |
| jax.make_jaxpr | JAX | Textual jaxpr representation of a traced function |
| Netron | ONNX, TF, PyTorch, Core ML | Cross-format graph viewer |
| TensorFlow Playground | TensorFlow | Browser visualizer for small networks |
| Network architecture diagrams | Any | Hand-drawn diagrams of model topology, often informally called computational graphs |
Computational graphs are powerful but not free.
Under JAX's jit, traced functions must be pure, so side effects and prints behave unexpectedly.
The last few years have brought significant changes to how computational graphs are built and executed.
PyTorch 2.0 introduced torch.compile, an opt-in compiler stack built on TorchDynamo, AOTAutograd, PrimTorch, and TorchInductor. It captures eager PyTorch code into graphs and produces Triton kernels on GPU.
All of these continue to rest on the same DAG abstraction that Theano popularized in 2010. The substrate has been remarkably stable; what has changed is the sophistication of the compilers that consume it and the scale of the models that produce it.