Abstract
It is widely believed that complex machine learning models generally encode features through linear representations. This is the foundational hypothesis behind a vast body of work on interpretability. A key challenge in extracting interpretable features, however, is that they exist in superposition. In this work, we study the question of extracting features in superposition from a learning-theoretic perspective. We start with the following fundamental setting: we are given query access to a function \[ f(x)=\sum_{i=1}^n \sigma_i(v_i^\top x), \qquad x \in \mathbb{R}^d, \] where each unit vector $v_i$ encodes a feature direction and $\sigma_i:\mathbb{R}\to\mathbb{R}$ is an arbitrary response function. Our goal is to recover the directions $v_i$ and the function $f$.
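To make the query model concrete, here is a minimal sketch of a toy instance of $f(x)=\sum_{i=1}^n \sigma_i(v_i^\top x)$ together with a noisy query oracle. Everything specific in it is an illustrative assumption rather than something taken from the paper: the helper names `make_superposition` and `noisy_oracle` are hypothetical, the directions are random unit vectors, the response functions are an arbitrary pool (tanh, cosine, square, ReLU), and the noise is assumed to be additive Gaussian. The example uses $n = 8$ directions in dimension $d = 3$, so it sits in the overcomplete regime.

```python
import numpy as np


def make_superposition(n, d, seed=0):
    """Toy instance of f(x) = sum_i sigma_i(v_i^T x) with x in R^d.

    The random directions and the particular response functions are
    illustrative assumptions, not taken from the paper.
    """
    rng = np.random.default_rng(seed)
    V = rng.standard_normal((n, d))
    V /= np.linalg.norm(V, axis=1, keepdims=True)  # row i is the unit vector v_i

    # Hypothetical response functions, cycled to give one sigma_i per feature.
    pool = [np.tanh, np.cos, lambda t: t ** 2, lambda t: np.maximum(t, 0.0)]
    sigmas = [pool[i % len(pool)] for i in range(n)]

    def f(x):
        proj = V @ x  # proj[i] = v_i^T x
        return sum(s(p) for s, p in zip(sigmas, proj))

    return V, sigmas, f


def noisy_oracle(f, noise=1e-3, seed=1):
    """Wrap f as a query oracle with additive Gaussian noise (assumed noise model)."""
    rng = np.random.default_rng(seed)

    def query(x):
        return f(x) + noise * rng.standard_normal()

    return query


# Overcomplete example: n = 8 feature directions in dimension d = 3.
V, sigmas, f = make_superposition(n=8, d=3)
query = noisy_oracle(f)
y = query(np.array([0.3, -1.0, 2.0]))
```

A learner in this setting only sees `query`; the matrix `V` and the list `sigmas` are what it would have to recover.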
In learning-theoretic terms, superposition corresponds to the \emph{overcomplete regime}, in which the number of features exceeds the underlying dimension (i.e., $n > d$); this regime has proven especially challenging for typical algorithmic approaches. Our main result is an efficient query algorithm that, from noisy oracle access to $f$, identifies all feature directions whose responses are non-degenerate and reconstructs the function $f$. Crucially, our algorithm works in a significantly more general setting than all related prior results. We allow essentially arbitrary superpositions, requiring only that $v_i$ and $v_j$ are not nearly identical for $i \neq j$, and we allow general response functions $\sigma_i$. At a high level, our algorithm introduces an approach for searching in Fourier space by iteratively refining the search space to locate the hidden directions $v_i$.
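The abstract does not spell out the Fourier-space search, so the following is only standard background intuition, not the authors' algorithm. A ridge function $\sigma(v^\top x)$ is constant along every direction orthogonal to $v$, so its Fourier transform is supported on the line spanned by $v$. A sum of $n$ such functions therefore concentrates on $n$ lines through the origin, even when $n > d$. The sketch below checks a discrete analogue on an $N \times N$ periodic grid. The grid size $N = 31$, the three integer directions, the random response profiles, and the helper names (`ridge_on_grid`, `fourier_support`, `freq_line`) are all hypothetical choices made for illustration. Three directions in two dimensions makes the toy overcomplete.

```python
import numpy as np

N = 31  # prime grid size: distinct lines through the origin meet only at (0, 0)


def ridge_on_grid(a, h):
    """Discrete ridge function g[x1, x2] = h[(a . x) mod N] on the N x N torus."""
    x1, x2 = np.meshgrid(np.arange(N), np.arange(N), indexing="ij")
    return h[(a[0] * x1 + a[1] * x2) % N]


def fourier_support(g, rel_tol=1e-8):
    """Frequencies (k1, k2) where the 2D DFT of g is non-negligible."""
    G = np.abs(np.fft.fft2(g))
    k1s, k2s = np.nonzero(G > rel_tol * G.max())
    return {(int(k1), int(k2)) for k1, k2 in zip(k1s, k2s)}


def freq_line(a):
    """The discrete line {m * a mod N : m = 0, ..., N-1} in frequency space."""
    return {((m * a[0]) % N, (m * a[1]) % N) for m in range(N)}


# Three hypothetical hidden directions in d = 2 (so n > d), random response profiles.
rng = np.random.default_rng(0)
directions = [(3, 4), (1, -2), (2, 7)]
g = sum(ridge_on_grid(a, rng.standard_normal(N)) for a in directions)

support = fourier_support(g)
# Every non-negligible frequency lies on one of the three lines through the origin.
print(support <= set().union(*(freq_line(a) for a in directions)))  # prints True
```

Reading the lines off the support recovers each direction up to scale. The paper's iterative refinement of the search space, which is what makes locating the $v_i$ efficient, is not modeled here.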