Mechanistic Interpretability: Review of Scalable Parameter Decomposition
I'm reviewing Towards Scalable Parameter Decomposition, authored by Lucius Bushnaq, Dan Braun, and Lee Sharkey at Goodfire (with work primarily carried out while at Apollo Research). The paper was published on June 27, 2025 and is available on arXiv; an accompanying research blog post by Michael Byun is available on the Goodfire Research site.
Before diving into SPD, let’s establish some context. The goal of mechanistic interpretability is to reverse engineer a neural network into its constituent “mechanisms” or “circuits.” This is analogous to decompiling a program to understand its functions and subroutines.
There are two main spaces to perform this decomposition:
1. Activation Space: Methods like Sparse Dictionary Learning (SDL) or Sparse Autoencoders (SAEs) try to find a basis of “features” in the activation space. They posit that any activation vector is a sparse linear combination of these feature vectors. This has conceptual issues, such as ignoring the geometry of features and not directly decomposing the network’s function.
2. Parameter Space: This is the newer, more ambitious approach. It posits that the network's weight matrices themselves can be decomposed. A network's parameters $W$ can be seen as a single point in a very high-dimensional space, and the idea is to express this point as a sum of simpler parameter vectors, each representing a distinct mechanism:
$$W = \sum_{c=1}^{C} \theta_c$$
The number of components $C$ can be larger than the rank of the weights, which is what enables superposition.
This is the core idea of Linear Parameter Decomposition (LPD), the framework to which both APD (the predecessor) and SPD (this paper’s contribution) belong.
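As a tiny illustrative sketch (PyTorch; the shapes and names are my own, not the paper's), a network's parameters can be flattened into one vector and expressed as a sum of component vectors:

```python
import torch

torch.manual_seed(0)

# Flatten a toy two-layer network's weights into a single parameter vector W.
layers = [torch.randn(8, 4), torch.randn(3, 8)]
W = torch.cat([w.flatten() for w in layers])        # one point in R^56

# A hypothetical decomposition into C components that sum back exactly to W.
C = 10                                              # C may exceed the weight rank
theta = torch.randn(C, W.numel())
theta -= (theta.sum(dim=0) - W) / C                 # enforce sum_c theta_c = W

assert torch.allclose(theta.sum(dim=0), W, atol=1e-5)
```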
The LPD Optimization Problem
The paper states that a good decomposition must satisfy three properties. Let's frame this as a multi-objective optimization problem, which is what these methods are implicitly solving. We want to find a set of parameter components $\{\theta_c\}_{c=1}^C$ that:
1. Faithfulness (Equality Constraint): The components must sum back to the original network's parameters.
$$\sum_c \theta_c = W$$
In practice, this is implemented as a “soft” constraint via a loss term (see the code sketch after this list):
$$\mathcal{L}_{\text{faithfulness}} = ||W - \sum_c \theta_c||_F^2$$
Minimizing this ensures the decomposition is a valid representation of the original model.
2. Minimality (Sparsity in Use): For any given input $x$, the number of components needed to replicate the model's output should be as small as possible. A network should only use a small subset of its mechanisms for any single task.
Constraint Formulation: This is the hard part. Let $S(x) \subset \{1, \dots, C\}$ be the set of “active” components for input $x$. We want to minimize $|S(x)|$ on average over the data distribution. The challenge is defining and finding this set $S(x)$; APD and SPD offer different solutions to this.
3. Simplicity (Component-wise Regularization): Each individual component $\theta_c$ should be as simple as possible, which the paper frames in terms of Minimum Description Length. A good explanation is a simple one. If our “mechanisms” $\theta_c$ are themselves incredibly complex (e.g., dense, full-rank, spanning all layers), they aren't very good explanatory primitives. We want them to be simple, for example low-rank or localized to a single layer.
Constraint Formulation: This is a regularization term on the components themselves, e.g., minimizing the nuclear norm (sum of singular values) to encourage low-rankness:
$$\sum_c ||\theta_c||_*$$
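As a hedged sketch of how these objectives might be expressed as loss terms (PyTorch, with illustrative shapes and names not taken from the paper), here are the faithfulness loss from item 1 and a nuclear-norm simplicity regularizer of the kind described in item 3; note that SPD itself, as discussed below, replaces the explicit simplicity term with a rank-one parameterization:

```python
import torch

def faithfulness_loss(W: torch.Tensor, components: torch.Tensor) -> torch.Tensor:
    """Squared Frobenius norm of the residual (W - sum_c theta_c).

    W:          (out, in)      original weight matrix
    components: (C, out, in)   stacked parameter components theta_c
    """
    residual = W - components.sum(dim=0)
    return (residual ** 2).sum()

def simplicity_loss(components: torch.Tensor) -> torch.Tensor:
    """Sum of nuclear norms ||theta_c||_*, encouraging each component to be low-rank."""
    return torch.linalg.svdvals(components).sum()   # batched over the C dimension

# Illustrative usage
W = torch.randn(4, 6)
theta = torch.randn(8, 4, 6, requires_grad=True)    # C = 8 components
loss = faithfulness_loss(W, theta) + 1e-2 * simplicity_loss(theta)
loss.backward()
```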
APD tried to solve this by using attribution scores to guess the active set $S(x)$ and then applying a simplicity regularizer. It was fragile and computationally expensive. SPD provides a more elegant, end-to-end differentiable solution.
The SPD Method
SPD tackles the LPD optimization problem with several key innovations.
The Decomposition Basis: Overcomplete Rank-One Subcomponents
Instead of decomposing the entire parameter vector $W$ at once, SPD decomposes each weight matrix $W^l$ individually. Furthermore, it chooses the simplest possible basis for a matrix: rank-one matrices.
Any matrix $W^l \in \mathbb{R}^{m \times n}$ can be written as a sum of rank-one matrices using Singular Value Decomposition (SVD). However, SVD gives at most $\min(m, n)$ orthogonal components. SPD uses a more flexible decomposition:
$$W^l \approx \sum_{c=1}^C \vec{U^l_c} \vec{V_c^{l \top}}$$
where $\vec{U^l_c}$ is a column vector and $\vec{V_c^{l \top}}$ is a row vector. Each term $\vec{U^l_c} \vec{V_c^{l \top}}$ is a rank-one matrix called a subcomponent.
Key Insight: The number of subcomponents $C$ can be larger than the rank of $W^l$. This is called an overcomplete basis, and it is crucial for models of superposition, where a network squeezes more features into a layer than its dimensionality would naively allow. You need more basis vectors (subcomponents) than dimensions to represent them.
This choice automatically handles the Simplicity objective. By defining the primitives to be rank-one matrices, we don’t need an explicit simplicity loss like APD did, which (as we’ll see) avoids issues like parameter shrinkage.
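A minimal sketch of the subcomponent parameterization (names like `U` and `V` follow the notation above; the shapes are illustrative):

```python
import torch

m, n = 16, 10     # dimensions of W^l
C = 40            # number of subcomponents; deliberately larger than rank(W^l) <= 10

# Learnable rank-one factors for one layer.
U = torch.randn(m, C, requires_grad=True)  # columns U^l_c
V = torch.randn(n, C, requires_grad=True)  # columns V^l_c

# Reconstruct the (approximate) weight matrix as a sum of rank-one terms:
# W^l ~ sum_c U^l_c V^l_c^T, computed here as a single matrix product.
W_approx = U @ V.T                                     # (m, n)

# The same thing written explicitly as a stack of C rank-one subcomponents.
subcomponents = U.T.unsqueeze(2) * V.T.unsqueeze(1)    # (C, m, n)
assert torch.allclose(subcomponents.sum(dim=0), W_approx, atol=1e-5)
```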
Minimality via Stochastic Ablation
How do we enforce minimality without relying on brittle attribution methods and a fixed `top-k`? SPD’s answer: turn it into a problem of robustness to noise.
Defining Causal Importance: The paper defines a subcomponent's causal importance on input $x$ as the extent to which it cannot be ablated.
Let $g^l_c(x) \in [0, 1]$ be the learned “causal importance” of subcomponent $c$ in layer $l$ for input $x$. A value of $g^l_c(x) = 1$ means the component is fully essential and cannot be touched; $g^l_c(x) = 0$ means the component is completely irrelevant and can be ablated to any degree without affecting the output.
The Stochastic Mask: To implement this, they introduce a random mask $m^l_c$ for each subcomponent. The weight matrix used in a forward pass is not the simple sum, but a stochastically masked sum:
$$W'^l(x) = \sum_{c=1}^C m^l_c(x) \cdot (\vec{U^l_c} \vec{V_c^{l \top}})$$
The mask $m^l_c(x)$ is a random variable sampled from a uniform distribution:
$$m^l_c(x) \sim \mathcal{U}(g^l_c(x), 1)$$
Intuition: If a component is important ($g_c \approx 1$), the mask is sampled from $\mathcal{U}(1, 1)$, so it's always $1$ (no ablation). If a component is unimportant ($g_c \approx 0$), the mask is sampled from $\mathcal{U}(0, 1)$, meaning it is randomly scaled down, on average by 50%. The model must learn a decomposition that is robust to this random “destruction” of unimportant components.
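Here is a hedged sketch of the masking step (PyTorch; the causal importance values `g` are assumed given, e.g. produced by a small learned predictor whose details I omit):

```python
import torch

def masked_weight(U: torch.Tensor, V: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
    """Build W'^l(x) = sum_c m_c * U_c V_c^T with m_c ~ Uniform(g_c, 1).

    U: (m, C), V: (n, C), g: (C,) causal importances in [0, 1].
    """
    r = torch.rand_like(g)          # r ~ Uniform(0, 1)
    mask = g + (1.0 - g) * r        # reparameterized sample from Uniform(g, 1)
    # Important components (g ~ 1) get mask ~ 1; unimportant ones (g ~ 0) get
    # a random scaling in [0, 1], i.e. a partial ablation.
    return (U * mask) @ V.T         # scale each rank-one term by its mask

U, V = torch.randn(16, 40), torch.randn(10, 40)
g = torch.rand(40)
W_prime = masked_weight(U, V, g)    # differentiable w.r.t. U, V and g
```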
The Optimization Landscape: This setup creates a beautiful tension between two loss terms:
1. Stochastic Reconstruction Loss $\mathcal{L}_{\text{stochastic-recon}}$: This loss penalizes differences between the original model's output $f(x, W)$ and the stochastically masked model's output $f(x, W'(x))$:
$$\mathcal{L}_{\text{stochastic-recon}} = \mathbb{E}_{r \sim \mathcal{U}(0,1)} \left[ D(f(x, W'(x, r)), f(x, W)) \right]$$
This loss pushes the causal importance values $g^l_c(x)$ towards $1$. If $g = 1$ for all components, there is no stochasticity, and the loss is minimized (assuming faithfulness is met).
2. Importance Minimality Loss $\mathcal{L}_{\text{importance-minimality}}$: This is a simple $L_p$ norm on the causal importance values, acting as a sparsity-inducing regularizer.
$$\mathcal{L}_{\text{importance-minimality}} = \sum_{l,c} |g^l_c(x)|^p$$
This loss pushes all $g^l_c(x)$ towards $0$.
Convergence Analysis: The optimizer must find a balance. For a given component $c$, if ablating it severely hurts the reconstruction loss, the optimizer will be forced to increase $g_c(x)$ to protect it, despite the penalty from the minimality loss. If, however, a component can be ablated with little harm, the minimality loss will successfully push its importance $g_c(x)$ to zero. This dynamic *discovers* the minimal set of essential components for each input.
Crucially, this entire process is differentiable (using the reparameterization trick for the sampling step), turning a discrete subset selection problem into a continuous, stochastic optimization problem.
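Putting the pieces together, a hedged, simplified training step might look like the following (single layer, MSE as the divergence $D$, and a toy importance predictor; this is illustrative scaffolding under my own assumptions, not the paper's exact implementation):

```python
import torch

# --- A single-layer sketch of an SPD-style training step --------------------
m, n, C, p = 16, 10, 40, 1.0
W_target = torch.randn(m, n)                              # frozen original weights
U = torch.nn.Parameter(0.1 * torch.randn(m, C))           # rank-one factors U^l_c
V = torch.nn.Parameter(0.1 * torch.randn(n, C))           # rank-one factors V^l_c
gate = torch.nn.Sequential(torch.nn.Linear(C, C), torch.nn.Sigmoid())  # toy g-predictor

opt = torch.optim.Adam([U, V, *gate.parameters()], lr=1e-3)

def training_step(x, lam_faith=1.0, lam_recon=1.0, lam_min=1e-3):
    # Faithfulness: the unmasked sum of subcomponents should reproduce W_target.
    l_faith = ((W_target - U @ V.T) ** 2).sum()

    # Causal importances g in [0, 1], predicted here from the inner activations x V.
    g = gate(x @ V)                                        # (batch, C)

    # Stochastic masks m ~ Uniform(g, 1), sampled via the reparameterization trick.
    mask = g + (1.0 - g) * torch.rand_like(g)

    # Masked forward pass: each example uses its own masked weight W'(x).
    out_masked = torch.einsum('bc,mc,nc,bn->bm', mask, U, V, x)
    out_target = x @ W_target.T
    l_recon = ((out_masked - out_target) ** 2).mean()      # divergence D = MSE

    # Importance minimality: Lp penalty on the importances.
    l_min = g.abs().pow(p).sum(dim=-1).mean()

    loss = lam_faith * l_faith + lam_recon * l_recon + lam_min * l_min
    opt.zero_grad(); loss.backward(); opt.step()
    return loss.item()

for _ in range(5):                                         # illustrative loop
    training_step(torch.randn(32, n))
```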
Why SPD is a Theoretical Improvement Over APD
1. No `top-k` Hyperparameter (Continuity): APD required choosing `k`, the number of active components. This introduces a hard, non-differentiable selection function. Gradient-based optimization struggles with such discontinuities. SPD's stochastic masking is “soft”; the expected number of active components emerges naturally from the optimization rather than being a fixed hyperparameter. The loss landscape is smoother.
2. Universal Gradient Flow (Avoiding Dead Components): In APD, if a component is never selected in the `top-k` for any batch, it receives no gradient signal and can never learn. It's a “dead” component. In SPD, every subcomponent participates in the forward pass via its mask $m_c$. Even if $g_c = 0$, the mask can be non-zero, allowing a gradient to flow back. This ensures all parameters are constantly updated and can “compete” to be useful.
3. Decoupled Objectives (Avoiding Shrinkage): APD's simplicity loss (e.g., spectral norm) was in direct tension with its faithfulness loss. To make a component “simple,” the optimizer would shrink its norm, but to make the components sum to $W$, the norms had to be large. This led to a compromised solution with shrunken components. SPD's minimality pressure is on the *importance value* $g_c$, not on the subcomponent norms $(\vec{U_c}, \vec{V_c})$. The job of getting the norms right is left entirely to the $\mathcal{L}_{\text{faithfulness}}$ loss, leading to a cleaner optimization and decompositions with no shrinkage (as seen in the ML2R metric being ~1.0).
4. Direct vs. Approximated Causality: APD used gradient attributions to estimate importance. This is a first-order Taylor approximation of a component’s effect:
$$\Delta \text{Output} \approx \nabla_{\theta_c} \text{Output} \cdot \Delta \theta_c$$
Such linear approximations are notoriously unreliable for large perturbations (like fully ablating a component). SPD, via its stochastic reconstruction loss, directly measures the effect of large, non-linear ablations. It learns a much more robust and empirically grounded model of causal importance.
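To illustrate why first-order attributions can mislead, here is a small hypothetical comparison between the Taylor estimate and the true effect of fully ablating a component (toy numbers, nothing from the paper):

```python
import torch

# A toy scalar "output" that depends nonlinearly on one component's parameters.
theta_c = torch.tensor([1.5, -0.8], requires_grad=True)
output = torch.tanh(theta_c.sum()) ** 2                 # stand-in for f(x, W)

# APD-style first-order attribution of fully ablating theta_c: gradient times delta.
(grad,) = torch.autograd.grad(output, theta_c)
delta = -theta_c.detach()                               # ablation sets theta_c to zero
linear_estimate = (grad * delta).sum()

# Ground truth: recompute the output with the component actually removed.
ablated_output = torch.tanh(torch.tensor(0.0)) ** 2
true_effect = ablated_output - output.detach()

print(f"linear estimate: {linear_estimate.item():.3f}  true effect: {true_effect.item():.3f}")
# For large perturbations like full ablation the two can differ substantially;
# SPD sidesteps this by optimizing against actual (stochastic) ablations.
```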
Analysis of Key Results
Toy Model of Superposition (TMS): SPD correctly identifies the ground truth mechanisms (the columns of the weight matrix) with near-perfect alignment (MMCS ≈ 1.0) and magnitude (ML2R ≈ 1.0). This shows it works on a canonical problem and avoids the shrinkage issue of APD.
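For readers unfamiliar with these metrics, here is a hedged sketch of how MMCS (mean max cosine similarity) and ML2R (mean L2-norm ratio) between learned and ground-truth mechanisms might be computed; the exact formulation in the paper may differ in details:

```python
import torch

def mmcs_and_ml2r(learned: torch.Tensor, truth: torch.Tensor):
    """learned: (C, d), truth: (K, d) flattened mechanism vectors.

    For each ground-truth mechanism, find its best-matching learned component
    by cosine similarity; report the mean similarity (MMCS) and the mean ratio
    of L2 norms of the matched pairs (ML2R; ~1.0 means no shrinkage).
    """
    cos = torch.nn.functional.normalize(truth, dim=-1) @ \
          torch.nn.functional.normalize(learned, dim=-1).T        # (K, C)
    best_sim, best_idx = cos.max(dim=-1)
    mmcs = best_sim.mean()
    ml2r = (learned[best_idx].norm(dim=-1) / truth.norm(dim=-1)).mean()
    return mmcs.item(), ml2r.item()

# Illustrative check: a perfect decomposition gives MMCS = ML2R = 1.0.
truth = torch.randn(5, 20)
print(mmcs_and_ml2r(truth.clone(), truth))   # -> (1.0, 1.0) up to float error
```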
TMS with Identity Matrix: This is a crucial sanity check against “feature splitting.” An identity matrix is a single, dense, rank-$m_1$ mechanism. An overzealous decomposition method might try to split it into many sparse “features.” SPD correctly decomposes the identity matrix into exactly $m_1$ rank-one subcomponents that are always active together. It doesn't find more components than necessary, because doing so would violate faithfulness without improving the other losses. This provides strong evidence that LPD, and SPD in particular, discovers functional mechanisms rather than arbitrary sparse features.
Toy Model of Compressed Computation: This result is subtle and important. SPD decomposes the input matrix $W_{in}$ into 100 subcomponents, one for each input feature, as expected. However, it decomposes the output matrix $W_{out}$ into 50 subcomponents that are all active simultaneously for every input. This means it correctly identifies $W_{out}$ as a single, monolithic, rank-50 mechanism responsible for projecting the MLP's computation back into the residual stream. APD got this wrong, splitting $W_{out}$ into 100 components, likely an artifact of its flawed `top-k` attribution method. SPD's solution is more parsimonious and aligns better with theoretical models.
The Significance of SPD
Stochastic Parameter Decomposition is not just an incremental improvement over APD; it is a practical and scalable tool.
Reframing the Problem: It reframes the difficult, combinatorial problem of finding the minimal active set of mechanisms as a continuous, stochastic optimization problem about learning robustness to ablation.
Improved Optimization: By removing hard discontinuities and ensuring universal gradient flow, it creates a much more stable and reliable training process.
Better Decompositions: The resulting decompositions are more accurate, avoid artifacts like the parameter shrinkage seen in APD, and, as demonstrated in the compressed computation model, hopefully provide better explanations of model function.
