# Computational graph

> Source: https://aiwiki.ai/wiki/computational_graph
> Updated: 2026-05-01
> Categories: Deep Learning, Developer Tools
> From AI Wiki (https://aiwiki.ai), a free encyclopedia of artificial intelligence. Quote with attribution.

A **computational graph** is a directed acyclic graph (DAG) representation of a numerical computation, where nodes represent operations (or variables) and edges represent the data, typically tensors, that flow between them. It is the central abstraction of modern deep-learning frameworks: every neural network forward pass is built as a computational graph, and the backward pass that computes gradients is performed by traversing this graph in reverse to apply the chain rule of calculus, a procedure known as reverse-mode automatic differentiation.

The graph view of computation predates deep learning by several decades, with roots in dataflow programming and in classical algorithmic differentiation literature (Griewank, Wengert). It became the dominant abstraction in machine learning around 2010, popularized by [Theano](/wiki/theano), and was carried forward into [TensorFlow](/wiki/tensorflow), [PyTorch](/wiki/pytorch), [JAX](/wiki/jax), MXNet, and most other modern frameworks. Understanding computational graphs is foundational to understanding how [backpropagation](/wiki/backpropagation) works, why GPU acceleration is possible at the operation level, and how compilers such as XLA, TorchInductor, and TVM optimize neural-network computation.

## Definition and basic structure

Formally, a computational graph for a function f is a DAG G = (V, E) in which:

- Each node v in V represents either an input variable, a constant, or an elementary operation (such as multiply, add, exp, matmul, conv2d, or relu).
- Each directed edge (u, v) in E represents the value produced by node u and consumed by node v.
- The graph is acyclic, so a topological order exists. Forward evaluation proceeds in topological order; backward differentiation proceeds in reverse topological order.

A simple example is the expression z = (x + y) * (x - y). The graph has two input nodes (x and y), two intermediate operation nodes (add and subtract), and one output node (multiply). Evaluating in topological order computes the add and subtract first, then feeds them into the multiply. To differentiate, the framework records this graph and walks it backward to accumulate partial derivatives.
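The following is a minimal sketch of forward evaluation. It assumes a toy encoding in which each node maps to an operation and the names of its inputs. The names `a` and `b` for the intermediate nodes are illustrative. Python's standard-library `graphlib.TopologicalSorter` supplies the topological order:

```python
from graphlib import TopologicalSorter
import operator

# Toy encoding of z = (x + y) * (x - y): node -> (operation, input node names).
# Input nodes have operation None and read their value from `inputs`.
GRAPH = {
    "x": (None, ()),
    "y": (None, ()),
    "a": (operator.add, ("x", "y")),  # a = x + y
    "b": (operator.sub, ("x", "y")),  # b = x - y
    "z": (operator.mul, ("a", "b")),  # z = a * b
}


def evaluate(graph, inputs):
    """Compute every node's value, visiting nodes in topological order."""
    predecessors = {node: set(args) for node, (_, args) in graph.items()}
    values = {}
    for node in TopologicalSorter(predecessors).static_order():
        op, args = graph[node]
        values[node] = inputs[node] if op is None else op(*(values[a] for a in args))
    return values


values = evaluate(GRAPH, {"x": 3.0, "y": 1.0})
print(values["z"])  # (3 + 1) * (3 - 1) = 8.0
```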

In deep-learning frameworks, the values flowing along edges are [tensors](/wiki/tensor): multi-dimensional arrays with a defined shape and dtype. The same DAG structure scales from a two-variable arithmetic example to transformer training graphs containing tens of thousands of nodes.

## Why computational graphs matter

Computational graphs are the data structure that makes modern deep learning practical. Three properties account for their dominance:

1. **Automatic differentiation.** Once the forward computation is recorded as a graph, gradients with respect to any input or parameter can be computed mechanically by reverse-mode autodiff. Researchers do not need to derive or hand-code derivatives for new model architectures.
2. **Hardware acceleration.** Each node corresponds to a known kernel that can be dispatched to a GPU, TPU, or other accelerator. The graph view makes it possible to identify groups of ops that can be fused into a single kernel, dramatically reducing memory traffic.
3. **Compilation and deployment.** Graphs can be serialized (for example, as ONNX or a saved model), shipped to inference servers, and re-optimized for different targets (TensorRT for NVIDIA GPUs, Core ML for Apple silicon, TFLite for mobile, and so on).

The abstraction is general enough to express not only feedforward networks but also recurrent networks, attention, convolutions, control flow, probabilistic models, and continuous-depth networks (Neural ODEs). Frameworks such as PyMC and Pyro reuse the same machinery for probabilistic programming.

## Two paradigms: static vs. dynamic graphs

Deep-learning frameworks fall into two broad camps depending on when the graph is constructed.

| Paradigm | Also known as | Construction time | Example frameworks | Strengths | Weaknesses |
|---|---|---|---|---|---|
| Static graph | Define-and-run, [graph execution](/wiki/graph_execution) | Graph is built once, then executed many times with different data | Theano, TensorFlow 1.x, Caffe, MXNet (Symbolic), JAX (under jit) | Whole-graph optimization, easier deployment, ahead-of-time compilation, predictable memory, easier serving | Awkward debugging, less Pythonic control flow, longer dev iteration |
| Dynamic graph | Define-by-run, [eager execution](/wiki/eager_execution) | Graph is built implicitly as Python code runs | Chainer, PyTorch, DyNet, TensorFlow 2.x default, Autograd | Pythonic, easy to debug with standard tools, natural support for variable-shape inputs and control flow | Harder to apply whole-program optimizations, traditionally slower for repeated identical computations |

The static approach was dominant from roughly 2010 to 2017 because it made it easier to extract performance from immature GPU stacks. The release of Chainer in 2015 introduced the define-by-run idea, which PyTorch adopted in 2017 and made mainstream. By 2019, TensorFlow had switched its 2.x default to eager mode while keeping a path to static graphs through `tf.function`. Today, most frameworks try to offer both: eager-by-default for development, with a way to compile hot regions into a static graph for production performance.

## Building a computational graph

A computational graph is built and used in two phases.

### Forward pass

During the forward pass, the framework computes the value of each node in topological order. In an eager framework like PyTorch, every Python expression that touches a `Tensor` triggers an actual kernel call and adds a node to the implicit graph. In a static framework like TensorFlow 1.x or under JAX's `jit`, the same Python code instead constructs a symbolic graph that will be executed later by a runtime or compiled to XLA.
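The difference between the two styles can be sketched in plain Python. The `Sym` class, `placeholder`, and `run` below are hypothetical stand-ins meant only to mimic the define-and-run style of TensorFlow 1.x. They are not any framework's API:

```python
import math


# Define-by-run (eager): every operation executes as soon as Python reaches it.
def eager_f(x, y):
    s = x + y  # the value exists immediately, so it can be printed or debugged
    return math.exp(s) * y


# Define-and-run (static): hypothetical symbolic nodes only record operations.
class Sym:
    def __init__(self, op, *args, name=None):
        self.op, self.args, self.name = op, args, name

    def __add__(self, other):
        return Sym("add", self, other)

    def __mul__(self, other):
        return Sym("mul", self, other)


def placeholder(name):
    return Sym("input", name=name)


def sym_exp(x):
    return Sym("exp", x)


def run(node, feed):
    """Execute a symbolic graph with concrete inputs (similar in spirit to Session.run)."""
    if node.op == "input":
        return feed[node.name]
    vals = [run(arg, feed) for arg in node.args]
    if node.op == "add":
        return vals[0] + vals[1]
    if node.op == "mul":
        return vals[0] * vals[1]
    if node.op == "exp":
        return math.exp(vals[0])
    raise ValueError(f"unknown op: {node.op}")


x, y = placeholder("x"), placeholder("y")
graph = sym_exp(x + y) * y  # builds the graph; computes nothing yet
print(run(graph, {"x": 0.0, "y": 2.0}), eager_f(0.0, 2.0))  # the same value, 2 * e**2
```

Nothing is computed when `graph` is built. The same graph object can then be executed repeatedly with different feeds, which is what lets a static framework analyze and optimize it once up front.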

For each node, the framework typically records:

- The operation type (the function being applied).
- References to its input tensors.
- The output tensor.
- A pointer to the corresponding gradient function for the backward pass.
- Optionally, any intermediate values needed to compute gradients (saved tensors).
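A per-operation record along these lines might look like the following. The `Node` dataclass, its field names, and the `mul` helper are assumptions for illustration, not any framework's internal types:

```python
from dataclasses import dataclass, field
from typing import Callable, Optional, Tuple


@dataclass
class Node:
    """Hypothetical per-operation record made during the forward pass."""
    op: str                    # operation type
    inputs: Tuple["Node", ...] # references to input nodes
    value: float               # the output value
    grad_fn: Optional[Callable[[float], Tuple[float, ...]]] = None  # upstream grad -> input grads
    saved: Tuple[float, ...] = field(default_factory=tuple)         # values kept for backward


def leaf(value: float) -> Node:
    return Node("leaf", (), value)


def mul(a: Node, b: Node) -> Node:
    # d(a*b)/da = b and d(a*b)/db = a, so both input values must be saved.
    saved = (a.value, b.value)
    return Node("mul", (a, b), a.value * b.value,
                grad_fn=lambda g: (g * saved[1], g * saved[0]), saved=saved)


out = mul(leaf(3.0), leaf(4.0))
print(out.op, out.value, out.saved, out.grad_fn(1.0))  # mul 12.0 (3.0, 4.0) (4.0, 3.0)
```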

### Backward pass

Given a scalar loss L, the backward pass computes dL/dx for every leaf tensor x that requires gradients. It does this by initializing dL/dL = 1 at the output, then walking the graph in reverse topological order. At each node, given the upstream gradient (the partial of L with respect to the node's output), it applies the local Jacobian to produce the gradient with respect to each of the node's inputs. These gradients flow back along the edges. Where a value feeds several consumers, its gradient contributions are summed (the multivariable chain rule), and the totals that reach the leaves are the parameter gradients.
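As a worked example, take z = (x + y) * (x - y) from earlier and name the intermediate nodes a = x + y and b = x - y. These labels are chosen here for illustration. Seeding ∂z/∂z = 1 and walking backward gives the following. Because x and y each feed both a and b, their two contributions are summed:

$$
\begin{aligned}
\frac{\partial z}{\partial a} &= b = x - y, \qquad \frac{\partial z}{\partial b} = a = x + y,\\
\frac{\partial z}{\partial x} &= \frac{\partial z}{\partial a}\,\frac{\partial a}{\partial x} + \frac{\partial z}{\partial b}\,\frac{\partial b}{\partial x} = (x - y)(1) + (x + y)(1) = 2x,\\
\frac{\partial z}{\partial y} &= \frac{\partial z}{\partial a}\,\frac{\partial a}{\partial y} + \frac{\partial z}{\partial b}\,\frac{\partial b}{\partial y} = (x - y)(1) + (x + y)(-1) = -2y.
\end{aligned}
$$

The result matches differentiating z = x² - y² directly.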

This is exactly reverse-mode automatic differentiation. In neural networks, where the input dimension (number of parameters) is far larger than the output dimension (a single scalar loss), reverse mode is dramatically more efficient than forward mode, which is why it became the standard.

## Modes of automatic differentiation

The choice of differentiation mode determines how the graph is traversed and how gradients are assembled. Baydin, Pearlmutter, Radul, and Siskind (2018) provide the canonical survey.

| Mode | What it computes | Cost (relative to forward) | When efficient | Used by |
|---|---|---|---|---|
| Forward-mode autodiff | Jacobian-vector products (JVP), one column of the Jacobian per pass | About 2 to 3 times the forward cost per JVP | Number of inputs <= number of outputs (tall Jacobian) | JAX (`jax.jvp`), PyTorch (`torch.autograd.functional.jvp`) |
| Reverse-mode autodiff (backpropagation) | Vector-Jacobian products (VJP), one row of the Jacobian per pass | About 2 to 4 times the forward cost per VJP, but stores activations | Number of inputs >> number of outputs (wide Jacobian, scalar loss) | All major DL frameworks |
| Symbolic differentiation | Closed-form derivative expression | Variable; can blow up due to expression swell | Small symbolic problems | Mathematica, SymPy, original Theano (partly) |
| Numerical differentiation | Finite differences | One forward eval per partial derivative | Tiny problems, gradient checking only | Often used to validate autodiff implementations |

For a typical neural network with millions of parameters and a single scalar loss, reverse-mode autodiff requires only a single backward pass to compute every gradient. This is why backpropagation has been the workhorse of neural-network training since it was popularized in the 1980s. Forward mode is occasionally useful for things like Hessian-vector products (when combined with reverse mode) and for differentiating through tall functions where the output is higher-dimensional than the input.
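Forward mode is commonly implemented with dual numbers, which carry a value and a tangent through every operation. The `Dual` class and `jvp` helper below are a hypothetical sketch and not the `jax.jvp` API. The central finite difference at the end plays the gradient-checking role listed for numerical differentiation in the table:

```python
import math


class Dual:
    """Forward-mode value: primal `val` and tangent `dot` (a directional derivative)."""

    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.dot + other.dot)

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule propagates the tangent alongside the value.
        return Dual(self.val * other.val, self.dot * other.val + self.val * other.dot)

    __radd__ = __add__
    __rmul__ = __mul__


def sin(d):
    return Dual(math.sin(d.val), math.cos(d.val) * d.dot)  # chain rule


def f(x, y):
    return x * y + sin(x)


def jvp(fn, primals, tangents):
    """One forward pass yields f(primals) and the Jacobian-vector product J @ tangents."""
    out = fn(*(Dual(p, t) for p, t in zip(primals, tangents)))
    return out.val, out.dot


def f_plain(x, y):
    return x * y + math.sin(x)


# Seeding the tangent (1, 0) extracts df/dx; (0, 1) would extract df/dy.
_, dfdx = jvp(f, (2.0, 3.0), (1.0, 0.0))  # analytically y + cos(x) = 3 + cos(2)

# Central finite difference, used here only as a gradient check.
h = 1e-6
fd = (f_plain(2.0 + h, 3.0) - f_plain(2.0 - h, 3.0)) / (2 * h)
print(dfdx, fd)
```

Each forward pass yields one column of the Jacobian, which is why forward mode pays off when inputs are few and outputs are many.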

## Key elements of a computational graph

Deep-learning frameworks share a small vocabulary of node types.

| Element | Role | Notes |
|---|---|---|
| [Tensor](/wiki/tensor) | Multi-dimensional array flowing along an edge | Has shape, dtype, and device; the basic data primitive |
| Operation (op) | Function applied at a node | Examples: matmul, add, conv2d, softmax, layernorm |
| Variable / Parameter | Tensor that requires gradients | Weights, biases; participates in optimizer updates |
| Constant | Tensor with no gradient | Input data in many setups, fixed lookup tables |
| Placeholder | Symbolic input slot (TF 1.x) | Filled at session run time; obsolete in TF 2.x eager mode |
| Saved tensor | Intermediate value cached for backward | Activations stored during forward pass |
| Gradient function | Backward op for each forward op | Encodes the local Jacobian of the operation |

Frameworks differ in how they expose these. PyTorch attaches a `grad_fn` attribute to every non-leaf tensor produced by a differentiable operation on inputs that require gradients. This attribute points to the autograd node that computes the backward pass for that operation. TensorFlow 1.x had explicit `tf.placeholder` and `tf.Variable` types. JAX traces pure functions into a `jaxpr`, its intermediate representation, which is then lowered and compiled by XLA.

## Static vs. dynamic in detail

The practical difference between the two paradigms shows up in the developer workflow.

| Aspect | Static graph workflow | Dynamic graph workflow |
|---|---|---|
| When the graph is built | Once, before any data flows | Implicitly, as each op runs |
| Debugging | Hard; errors surface at session run, not at the offending Python line | Easy; standard Python debuggers work, errors point at the line |
| Variable shapes | Must be known (or symbolic) at graph build time | Can vary per call |
| Control flow | Encoded with framework-specific ops (`tf.cond`, `tf.while_loop`) | Native Python `if`, `for`, `while` |
| Optimization opportunities | High; whole-graph view enables operator fusion, layout transforms, scheduling | Limited unless a separate trace/compile step is added |
| Deployment | Easy to serialize and ship the graph | Requires a separate capture or export step (TorchScript, `torch.export`, ONNX) |
| Performance per call | Often higher after compilation | Often lower per op, but no compile latency |

Most modern systems blur the line. PyTorch 2.0 introduced `torch.compile`, which captures eager Python code into a graph and compiles it via TorchInductor. JAX's `jax.jit` decorator traces a pure Python function on its first call for each combination of input shapes and dtypes. It then compiles the resulting `jaxpr` with XLA, giving define-by-run ergonomics with static-graph performance. TensorFlow's `tf.function` does the equivalent for TF 2.x.
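The trace-then-reuse behavior can be illustrated with a toy tracer. Everything here is a hypothetical simplification, not how JAX or `tf.function` are implemented. That includes `Tracer`, `toy_jit`, the list-based "arrays", and the use of list lengths as the shape signature:

```python
# Hypothetical toy tracer: illustrates trace-once-per-shape, NOT the JAX API.
class Tracer:
    """Stands in for an array during tracing; operations are recorded, not computed."""

    def __init__(self, name, program):
        self.name, self.program = name, program

    def _emit(self, op, other):
        out = Tracer(f"t{len(self.program)}", self.program)
        self.program.append((out.name, op, self.name, other.name))
        return out

    def __add__(self, other):
        return self._emit("add", other)

    def __mul__(self, other):
        return self._emit("mul", other)


# Elementwise kernels over plain lists, standing in for compiled device code.
KERNELS = {
    "add": lambda a, b: [u + v for u, v in zip(a, b)],
    "mul": lambda a, b: [u * v for u, v in zip(a, b)],
}


def toy_jit(fn):
    cache = {}  # shape signature -> (recorded program, output name)
    stats = {"traces": 0}

    def compiled(*args):
        key = tuple(len(a) for a in args)  # list length plays the role of a shape
        if key not in cache:
            stats["traces"] += 1
            program = []
            tracers = [Tracer(f"in{i}", program) for i in range(len(args))]
            out = fn(*tracers)  # the Python body runs once, recording operations
            cache[key] = (program, out.name)
        program, out_name = cache[key]
        env = {f"in{i}": a for i, a in enumerate(args)}
        for dst, op, lhs, rhs in program:  # replay the recorded program
            env[dst] = KERNELS[op](env[lhs], env[rhs])
        return env[out_name]

    compiled.stats = stats
    return compiled


@toy_jit
def f(x, y):
    return (x + y) * x


print(f([1.0, 2.0], [3.0, 4.0]))  # traced on first call: [4.0, 12.0]
print(f([5.0, 6.0], [1.0, 1.0]))  # same shapes, program reused: [30.0, 42.0]
print(f([1.0], [1.0]))            # new shape, re-traced: [2.0]
print(f.stats["traces"])          # 2
```

The Python body of `f` runs only when a new shape signature appears. Later calls with the same shapes replay the recorded program.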

## Major frameworks and their graph approaches

| Framework | First release | Default mode | Graph approach | Status |
|---|---|---|---|---|
| Theano | 2010 | Static | Symbolic computational graph in Python with C++/CUDA codegen | Discontinued in 2017; spiritual successor in PyMC's Aesara/PyTensor |
| Torch (Lua) | 2002 / DL focus 2011 | Eager | Imperative tensor library with hand-coded backprop modules | Largely replaced by PyTorch |
| Caffe | 2013 | Static | Layer-based prototxt graph | Caffe2 merged into PyTorch in 2018 |
| TensorFlow 1.x | 2015 | Static | `tf.placeholder` + `tf.Session.run`; explicit graphs | Superseded by 2.x |
| Chainer | 2015 | Dynamic | Pioneered define-by-run; influenced PyTorch | Maintenance mode since 2019 |
| MXNet | 2015 | Both | Symbolic API plus Gluon imperative API; hybridization for production | Apache repo archived November 2023 |
| PyTorch | 2017 | Dynamic | Autograd tape records ops as they run | Active; default research framework |
| DyNet | 2017 | Dynamic | Dynamic-graph framework aimed at NLP, released around the same time as PyTorch | Maintained but niche |
| TensorFlow 2.x | 2019 | Eager | Eager by default, `tf.function` traces to static graphs | Active |
| JAX | 2018 | Functional | `grad` and `jit` transform pure Python functions; produces `jaxpr` for XLA | Active; growing in research |
| PyTorch 2.0 | March 2023 | Dynamic + compile | TorchDynamo captures graphs from eager Python; AOTAutograd handles backward; TorchInductor lowers to Triton/C++ | Active |
| Flux.jl | 2017 | Dynamic | Differentiable programming in Julia via Zygote | Active |
| Equinox / Diffrax | 2021+ | Functional | Built on JAX: Equinox provides pytree-based modules, Diffrax provides differential-equation solvers | Active |

The progression from Theano to PyTorch 2.0 traces a slow convergence: static-graph systems are adding eager front-ends for usability, and eager systems are adding compilation backends for performance. The two communities are arriving at the same answer from opposite directions.

## Compilation and optimization

Once a computational graph exists, compilers can apply transformations that would be impossible at the Python interpreter level.

| Compiler / format | Used by | What it does |
|---|---|---|
| XLA (Accelerated Linear Algebra) | TensorFlow, JAX, PyTorch/XLA | Operator fusion, layout transforms, sharding for TPUs and GPUs |
| TorchScript | PyTorch 1.x | JIT-compiled IR for Python-free deployment; being phased out in favor of `torch.compile` and ONNX exporter |
| TorchDynamo | PyTorch 2.0+ | Captures Python frames into FX graphs using Python's frame evaluation API |
| AOTAutograd | PyTorch 2.0+ | Traces forward and backward passes ahead of time so both can be compiled |
| TorchInductor | PyTorch 2.0+ | Lowers graphs to Triton kernels on GPU and C++/OpenMP on CPU |
| TVM | Many | Open-source ML compiler stack with two-level Relay/TIR IR |
| MLIR | TensorFlow, others | LLVM-style multi-level IR infrastructure for ML compilers |
| ONNX | Cross-framework | Interchange format for static graphs; widely used for inference |
| [TensorRT](/wiki/tensorrt) | NVIDIA inference | Imports ONNX or PyTorch graphs and produces optimized engines for NVIDIA GPUs |

The most common transformation is **operator fusion**: collapsing several small ops (such as add, multiply, and relu) into a single kernel. This eliminates intermediate writes to global memory and reduces kernel-launch overhead. XLA, TVM, and Inductor all do this aggressively. Hand-written fused kernels take the idea further. FlashAttention computes attention in tiles so that the full attention matrix is never written to GPU global memory. This moves attention from a memory-bound operation toward a compute-bound one in both LLM training and serving. Triton makes kernels of this kind easier to write.
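The following is a Python analogy for fusing add, multiply, and relu. It assumes each list comprehension stands in for one GPU kernel whose output is written to global memory. Pure Python will not show the speedup; the point is the number of intermediate buffers:

```python
def relu(v):
    return v if v > 0.0 else 0.0


def unfused(x, y, z):
    """Three 'kernels', each materializing a full intermediate buffer."""
    t1 = [a + b for a, b in zip(x, y)]   # kernel 1: write t1
    t2 = [a * c for a, c in zip(t1, z)]  # kernel 2: read t1, write t2
    return [relu(v) for v in t2]         # kernel 3: read t2, write the output


def fused(x, y, z):
    """One 'kernel': each element is loaded once and combined without intermediates."""
    return [relu((a + b) * c) for a, b, c in zip(x, y, z)]


x, y, z = [1.0, -2.0, 3.0], [0.5, 0.5, -4.0], [2.0, 2.0, 2.0]
print(unfused(x, y, z), fused(x, y, z))  # [3.0, 0.0, 0.0] [3.0, 0.0, 0.0]
```

A graph compiler can perform this rewrite automatically once it sees that the intermediates `t1` and `t2` have no other consumers.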

## Reverse-mode autodiff in detail

The core algorithm of every deep-learning trainer is reverse-mode autodiff over a computational graph. Step by step:

1. Define the loss L as a function of parameters θ and inputs x.
2. Build the computational graph for L by running the forward pass. Each op records its inputs, output, and a pointer to its gradient function.
3. Initialize the gradient at the output: dL/dL = 1.
4. Walk the graph in reverse topological order. At each node, given the upstream gradient ∂L/∂y where y is the node's output, compute ∂L/∂x_i = (∂L/∂y) * (∂y/∂x_i) for each input x_i. The local Jacobian ∂y/∂x_i is implemented by the op's backward function.
5. When several edges feed into the same node, sum the contributions. This is the multivariable chain rule.
6. Accumulate gradients at every leaf parameter node: dL/dθ.
7. Hand the gradients to an optimizer (SGD, Adam, AdamW, etc.) which updates θ.
8. Discard or reset the graph (in dynamic frameworks, the graph is rebuilt on every iteration; in static frameworks, the same graph is reused).
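The eight steps can be sketched as a tiny scalar autodiff engine in the spirit of teaching libraries. The `Value` class, its `backward` method, and the toy regression data are assumptions for illustration. Real frameworks work on tensors and run the backward traversal in a native engine:

```python
class Value:
    """Minimal scalar reverse-mode autodiff node (an illustrative sketch only)."""

    def __init__(self, data, parents=(), backward=lambda g: ()):
        self.data, self.grad = data, 0.0
        self.parents = parents     # input nodes recorded during the forward pass
        self._backward = backward  # upstream grad -> one local grad per parent

    @staticmethod
    def _lift(x):
        return x if isinstance(x, Value) else Value(x)

    def __add__(self, other):
        other = Value._lift(other)
        return Value(self.data + other.data, (self, other), lambda g: (g, g))

    def __sub__(self, other):
        other = Value._lift(other)
        return Value(self.data - other.data, (self, other), lambda g: (g, -g))

    def __mul__(self, other):
        other = Value._lift(other)
        return Value(self.data * other.data, (self, other),
                     lambda g: (g * other.data, g * self.data))

    def backward(self):
        order, seen = [], set()

        def visit(node):  # depth-first topological sort of the recorded graph
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node.parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0               # step 3: dL/dL = 1
        for node in reversed(order):  # step 4: reverse topological order
            for parent, g in zip(node.parents, node._backward(node.grad)):
                parent.grad += g      # step 5: sum contributions from all consumers


# Toy data for y = 2x + 1 (illustrative).
data = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0), (3.0, 7.0)]
w, b, lr = 0.0, 0.0, 0.05

for step in range(500):
    W, B = Value(w), Value(b)  # leaf parameters; the graph is rebuilt every iteration (step 8)
    loss = Value(0.0)
    for x, y in data:          # steps 1-2: the forward pass records the graph
        err = W * x + B - y
        loss = loss + err * err
    loss.backward()            # steps 3-6: gradients accumulate in W.grad and B.grad
    w -= lr * W.grad           # step 7: plain gradient-descent update
    b -= lr * B.grad

print(round(w, 3), round(b, 3))  # approaches 2.0 and 1.0
```

Because `backward` sums into `parent.grad`, the `err * err` node, whose two inputs are the same object, correctly contributes `2 * err` times the upstream gradient.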

The efficiency of this scheme is what makes training models with hundreds of billions of parameters feasible. A backward pass typically costs about twice as much as the forward pass, so a full training step is roughly three forward passes' worth of compute. It produces gradients with respect to every parameter at once, regardless of how many parameters there are.

## Edge cases and special operations

Real models need a handful of operations that go beyond plain forward and backward.

- **Stop gradient.** PyTorch's `tensor.detach()`, TensorFlow's `tf.stop_gradient`, and `jax.lax.stop_gradient` all prevent gradient flow through a sub-expression. This is essential for techniques like target networks in reinforcement learning, REINFORCE estimators, and certain generative training setups.
- **Higher-order derivatives.** Because the backward pass is itself a computation that can be recorded into a graph, autograd works recursively. PyTorch's `torch.autograd.grad(create_graph=True)` and JAX's nested `grad` calls produce gradients of gradients (Hessians, third derivatives, and so on). Used for meta-learning (MAML), influence functions, and physics-informed models.
- **Custom autograd functions.** When an op has no closed-form gradient or a more efficient handwritten backward, frameworks let users define forward and backward manually (`torch.autograd.Function`, `tf.custom_gradient`, `jax.custom_vjp`).
- **Gradient checkpointing.** Trades compute for memory by discarding selected activations during the forward pass and recomputing them during the backward pass. Chen et al. (2016) show that checkpointing an n-layer network every √n layers reduces activation memory to O(√n) at the cost of roughly one extra forward pass. Critical for training transformers whose activations would otherwise exceed GPU memory.
- **Mixed precision.** Forward and backward are run in FP16 or BF16 while a master copy of the weights stays in FP32. The framework casts on the fly and may use loss scaling to avoid underflow.
- **Stochastic operations.** Dropout uses a fixed mask during a single forward/backward; sampling from distributions uses the reparameterization trick or score-function estimators (REINFORCE) to provide a differentiable surrogate.
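- **Distributed training.** DDP (distributed data parallel) replicates the model and averages gradients across devices with an all-reduce, typically over NCCL on GPUs. FSDP (fully sharded data parallel) shards parameters, gradients, and optimizer state across devices. Both are layered on top of the standard graph, inserting collective operations (all-reduce, all-gather, reduce-scatter) around the forward and backward computation.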
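Gradient checkpointing can be sketched on a chain of n identical scalar layers. The toy `layer` function and both helper names are assumptions. Storing every activation keeps n + 1 values. Keeping only every k-th layer input, with k ≈ √n, and recomputing each segment during the backward pass keeps roughly 2√n values at peak. This costs close to one extra forward pass, in line with the Chen et al. (2016) result:

```python
import math


def layer(h):
    return math.sin(h) + 0.5 * h  # toy layer (hypothetical)


def layer_grad(h):
    return math.cos(h) + 0.5  # its derivative, evaluated at the layer's input h


def grad_store_all(x, n):
    """Keep every activation for the backward pass: n + 1 stored values."""
    acts = [x]
    for _ in range(n):
        acts.append(layer(acts[-1]))
    g = 1.0
    for h in reversed(acts[:-1]):  # backward walks the layer inputs in reverse
        g *= layer_grad(h)
    return acts[-1], g, len(acts)


def grad_checkpointed(x, n, k):
    """Keep only the input to every k-th layer; recompute each segment during backward."""
    checkpoints = {0: x}
    h = x
    for i in range(n):
        h = layer(h)
        if (i + 1) % k == 0 and i + 1 < n:
            checkpoints[i + 1] = h
    out, g, peak = h, 1.0, len(checkpoints)
    for start in sorted(checkpoints, reverse=True):
        end = min(start + k, n)
        seg = [checkpoints[start]]  # recompute inputs of layers start .. end-1
        for _ in range(start, end - 1):
            seg.append(layer(seg[-1]))
        peak = max(peak, len(checkpoints) + len(seg))  # rough count of live values
        for h_in in reversed(seg):
            g *= layer_grad(h_in)
    return out, g, peak


full = grad_store_all(0.3, 100)
ckpt = grad_checkpointed(0.3, 100, 10)       # k = sqrt(n)
print(full[1] == ckpt[1], full[2], ckpt[2])  # True 101 20
```

Both versions return identical gradients because the recomputed activations come from the same deterministic operations.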

## Memory and performance considerations

The activations stored during the forward pass for use during the backward pass are the dominant memory cost in modern training. Memory grows roughly linearly with the depth of the network and with the batch size. For a transformer, attention activations scale quadratically with sequence length unless techniques like FlashAttention recompute them on the fly.

Three common levers reduce memory use:

- **Gradient checkpointing**, as described above, recomputes selected blocks during the backward pass.
- **Compiler fusion** (XLA, Inductor) reduces intermediate buffers by combining ops into single kernels.
- **Activation offloading and quantization** move activations to CPU memory or store them in lower precision. These techniques are used in libraries like DeepSpeed and FSDP.

On the throughput side, fused kernels reduce kernel-launch overhead and global-memory traffic. FlashAttention, written in CUDA and later in Triton, restructures attention as a single fused kernel that never materializes the full attention matrix. Similar fusion at the graph level is what `torch.compile` and `jax.jit` automate.
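The core of FlashAttention's approach is an online softmax. It folds keys into the result one block at a time and rescales earlier partial sums whenever a larger score appears. A simplified sketch follows for a single query with scalar values. The function names are hypothetical; the real kernels process tiles of keys and vector-valued rows in on-chip SRAM:

```python
import math


def attention_naive(scores, values):
    """Materializes every softmax weight before taking the weighted sum."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


def attention_online(scores, values):
    """One streaming pass keeping only a running max, normalizer, and accumulator."""
    m, denom, acc = -math.inf, 0.0, 0.0
    for s, v in zip(scores, values):
        m_new = max(m, s)
        scale = math.exp(m - m_new)  # rescale earlier partial sums to the new max
        p = math.exp(s - m_new)
        denom = denom * scale + p
        acc = acc * scale + p * v
        m = m_new
    return acc / denom


scores = [2.0, 1.0, 0.5, 3.0]   # one query's scores against four keys
values = [1.0, -1.0, 2.0, 0.5]  # scalar values for simplicity
print(attention_naive(scores, values), attention_online(scores, values))
```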

## Connection to broader concepts

Computational graphs sit at the intersection of several older ideas in computer science and applied mathematics.

| Concept | Connection |
|---|---|
| Dataflow programming | Computational graphs are a special case; nodes fire when inputs are available |
| Symbolic differentiation | Computer algebra systems (Macsyma, and later Maple and Mathematica) build expression trees and apply differentiation rules; the resulting expressions can suffer from expression swell |
| Algorithmic differentiation (Griewank) | The classical autodiff literature predates ML by decades; reverse mode was known as adjoint mode |
| Differentiable programming | A term popularized by Yann LeCun for treating any differentiable program as a model; computational graphs are the substrate |
| Probabilistic programming | PyMC, Pyro, NumPyro, and Stan all rely on autodiff over computational graphs to sample posteriors |
| Neural ODEs (Chen et al., 2018) | Continuous-depth networks; gradients computed via the adjoint method, which is reverse-mode autodiff in continuous time |
| Differentiable simulators | Physics simulators built on top of JAX/PyTorch differentiate through entire simulations |

## Visualization tools

Seeing the graph helps with debugging and architecture design. Common tools:

| Tool | Framework | Output |
|---|---|---|
| TensorBoard | TensorFlow, PyTorch | Interactive web UI for graph and training metrics |
| torchview / torchsummary | PyTorch | Architecture diagrams (torchview) and layer-by-layer text summaries (torchsummary) |
| `torch.fx` GraphModule | PyTorch | Programmatic IR for inspection and transformation |
| `jax.make_jaxpr` | JAX | Textual `jaxpr` representation of a traced function |
| Netron | ONNX, TF, PyTorch, Core ML | Cross-format graph viewer |
| TensorFlow Playground | None (standalone web app) | Interactive browser visualizer for training small networks |
| Network architecture diagrams | Any | Hand-drawn diagrams of model topology, often informally called computational graphs |

## Limitations and challenges

Computational graphs are powerful but not free.

- **Memory consumption** is the central bottleneck for training large models, especially transformers, where attention activations dominate.
- **Compilation latency** is a real cost in workflows that change shapes or control flow often. JIT compilers re-trace when input shapes change, which can stall iteration.
- **Numerical precision.** Backprop can underflow in FP16; mixed-precision training and loss scaling exist to mitigate this.
- **Non-differentiable ops** like rounding, indexing with discrete indices, and argmax need workarounds (Gumbel-Softmax, straight-through estimators).
- **Stochastic ops.** Sampling is differentiable only under reparameterization or with surrogate gradient estimators.
- **Dynamic shapes and control flow** complicate compilation. Both PyTorch and JAX have invested heavily in handling shape polymorphism, with mixed results.
- **Debugging compiled graphs** is harder than debugging eager Python. PyTorch's TorchDynamo reports graph breaks where it falls back to eager execution. JAX's `jit` assumes pure functions, so Python side effects such as `print` run only while tracing, not on every call.

## Recent developments

The last few years have brought significant changes to how computational graphs are built and executed.

- **PyTorch 2.0** (March 2023) introduced `torch.compile`, an opt-in compiler stack built on TorchDynamo, AOTAutograd, PrimTorch, and TorchInductor. It captures eager PyTorch code into graphs and produces Triton kernels on GPU.
- **JAX** has continued to grow in the research community for its functional purity and strong compilation story via XLA. Libraries like Flax, Haiku, and Equinox build neural-network abstractions on top.
- **Compiler-level fusion** has become essential for LLM serving. Triton, FlashAttention-2 and -3, and the broader inductor ecosystem squeeze more performance out of fixed hardware.
- **ML compiler stacks** such as Modular's MAX platform and Mojo language, MLC LLM, and TVM Unity are pushing graph-based optimization further into model deployment.
- **Tensor parallelism and pipeline parallelism** for models that no longer fit on a single device are represented as additional structure within the graph, with collective ops (all-reduce, all-gather, reduce-scatter) treated as first-class operations.
- **LLM inference engines** (vLLM, TensorRT-LLM, SGLang) manage KV-cache reuse and continuous batching around the model graph, blurring the line between model graph and serving runtime.

All of these continue to rest on the same DAG abstraction that Theano popularized in 2010. The substrate has been remarkably stable; what has changed is the sophistication of the compilers that consume it and the scale of the models that produce it.

## References

1. Goodfellow, I., Bengio, Y., and Courville, A. (2016). *Deep Learning*. MIT Press. Chapter 6.5 covers backpropagation and computational graphs. https://www.deeplearningbook.org/
2. Bergstra, J. et al. (2010). "Theano: A CPU and GPU Math Compiler in Python." *Proceedings of the 9th Python in Science Conference (SciPy 2010)*. https://conference.scipy.org/proceedings/scipy2010/bergstra.html
3. Abadi, M. et al. (2016). "TensorFlow: A System for Large-Scale Machine Learning." *Proceedings of OSDI '16*, pp. 265-283. https://www.usenix.org/system/files/conference/osdi16/osdi16-abadi.pdf
4. Paszke, A. et al. (2017). "Automatic Differentiation in PyTorch." *NIPS Autodiff Workshop*. https://openreview.net/forum?id=BJJsrmfCZ
5. Paszke, A. et al. (2019). "PyTorch: An Imperative Style, High-Performance Deep Learning Library." *NeurIPS 2019*. https://arxiv.org/abs/1912.01703
6. Frostig, R., Johnson, M. J., and Leary, C. (2018). "Compiling Machine Learning Programs via High-Level Tracing." *MLSys 2018*. https://cs.stanford.edu/~rfrostig/pubs/jax-mlsys2018.pdf
7. Baydin, A. G., Pearlmutter, B. A., Radul, A. A., and Siskind, J. M. (2018). "Automatic Differentiation in Machine Learning: A Survey." *Journal of Machine Learning Research*, 18(153):1-43. https://www.jmlr.org/papers/volume18/17-468/17-468.pdf
8. Maclaurin, D., Duvenaud, D., and Adams, R. P. (2015). "Autograd: Effortless Gradients in NumPy." *AutoML Workshop, ICML 2015*.
9. Griewank, A. and Walther, A. (2008). *Evaluating Derivatives: Principles and Techniques of Algorithmic Differentiation*, 2nd edition. SIAM.
10. Tokui, S. et al. (2019). "Chainer: A Deep Learning Framework for Accelerating the Research Cycle." *KDD 2019*. https://arxiv.org/abs/1908.00213
11. PyTorch team. "PyTorch 2.x: torch.compile, TorchDynamo, AOTAutograd, TorchInductor." https://pytorch.org/get-started/pytorch-2-x/
12. JAX team. "Just-in-time compilation" and "jaxpr" documentation. https://docs.jax.dev/
13. Apache MXNet team. "MXNet 2.0 API Deprecation RFC" and repository archival notice (November 2023). https://github.com/apache/mxnet
14. Chen, R. T. Q., Rubanova, Y., Bettencourt, J., and Duvenaud, D. (2018). "Neural Ordinary Differential Equations." *NeurIPS 2018*. https://arxiv.org/abs/1806.07366
15. Snider, D. and Liang, R. (2023). "Operator Fusion in XLA: Analysis and Evaluation." https://arxiv.org/abs/2301.13062
16. Chen, T. et al. (2016). "Training Deep Nets with Sublinear Memory Cost" (gradient checkpointing). https://arxiv.org/abs/1604.06174

