Revisiting Auxiliary Losses for Conditional Depth Routing: An Empirical Study

arXiv cs.LG / 4/21/2026


Key Points

  • The paper investigates how auxiliary losses affect gate-training stability in conditional depth token routing, where a gating module sends some tokens through a cheap FFN and the rest through the full FFN at each controlled layer.
  • It compares two gate designs (G1: MLP utility scoring vs. G3: JEPA-guided action-conditional prediction) on a 157.5M-parameter decoder-only model with controller-only training and a 50% full-path budget. Under the standard util/rank auxiliary-loss recipe, G3 improves early-to-mid optimization in 3/3 seeds.
  • Ablation results show that removing the util/rank auxiliary supervision improves best and average LM loss and threshold-hit speed for both gates, and G3's earlier advantage over G1 disappears.
  • The authors attribute the negative effect of the utility/rank losses to an off-policy oracle label: it assumes all subsequent layers run the full FFN, whereas gated execution routes only a fraction of tokens through the full path.
  • Eliminating util/rank also reduces the training compute proxy from about 1.53× to about 1.07× the full-only cost (2.87h to 1.75h on a V100-32GB, roughly 39% less), suggesting a practical efficiency benefit within the tested regime.
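The per-layer routing decision in the first key point can be sketched as follows. This is a minimal illustration, not the authors' implementation: it assumes the gate emits one scalar score per token and that the 50% full-path budget is enforced by hard top-k selection within each layer. `route_tokens`, `full_ffn`, and `cheap_ffn` are hypothetical names, and the residual connection around the FFN is omitted.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def route_tokens(
    hidden: Sequence[Vector],
    scores: Sequence[float],
    full_ffn: Callable[[Vector], Vector],
    cheap_ffn: Callable[[Vector], Vector],
    budget: float = 0.5,
) -> Tuple[List[Vector], List[bool]]:
    """Route the highest-scoring `budget` fraction of tokens through `full_ffn`.

    Assumption (not from the paper): the budget is enforced by hard top-k
    selection per layer; the residual connection is omitted for brevity.
    """
    if len(hidden) != len(scores):
        raise ValueError("need exactly one gate score per token")
    if not 0.0 <= budget <= 1.0:
        raise ValueError("budget must lie in [0, 1]")
    n = len(hidden)
    k = math.ceil(budget * n)  # tokens allowed on the full path at this layer
    # Rank tokens by gate score (highest first); ties go to the earlier position.
    order = sorted(range(n), key=lambda i: (-scores[i], i))
    use_full = [False] * n
    for i in order[:k]:
        use_full[i] = True
    # Each token runs exactly one of the two FFNs at this layer.
    outputs = [full_ffn(h) if f else cheap_ffn(h) for h, f in zip(hidden, use_full)]
    return outputs, use_full
```

For example, with four tokens, scores `[0.9, 0.1, 0.5, 0.7]`, and `budget=0.5`, tokens 0 and 3 take the full path and tokens 1 and 2 take the cheap path. Because the selection here is a hard top-k, the LM loss sends no gradient to the gate scores through this step, which is one way to see why extra training signals for the gate are attractive.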

Abstract

Conditional depth execution routes a subset of tokens through a lightweight cheap FFN while the remainder execute the standard full FFN at each controlled layer. The central difficulty is gate training: the gate decision must propagate through many layers before it influences the language modeling (LM) loss, so the resulting gradients are weak and noisy. Auxiliary losses are commonly stacked to stabilise training, yet the interactions among them -- particularly between a predictive auxiliary and explicit score supervision -- have not been systematically compared under controlled conditions. We evaluate two gate designs under a 157.5M-parameter decoder-only model with controller-only training, 50% full-path budget, and 3-seed runs on a fineweb-edu subset. The MLP gate (G1) maps the current hidden state to a utility score; the JEPA-guided gate (G3) adds an action-conditional predictor that forecasts, in a low-dimensional latent space, the outcome of executing full vs. cheap per token, aligned against a fixed target head. Under the standard recipe with oracle-style utility regression and pairwise rank supervision (util/rank), G3 improves early-to-mid optimisation over G1 in 3/3 seeds (lower avg LM, faster threshold hits, ~10.3x lower grad norms), with 20k-step endpoint LM within a 0.005 heuristic reference. A key finding (ablation A3): jointly removing util/rank improves best/avg LM and threshold-hit speed in 3/3 seeds for both gates, and the early-to-mid advantage of G3 over G1 disappears. We trace this to an off-policy oracle label that assumes all subsequent layers execute full, whereas gated execution routes only a fraction through full -- making util/rank net-negative under the current recipe. Removing util/rank also cuts the training FLOPs proxy from ~1.53x to ~1.07x full-only (2.87h to 1.75h on a V100-32GB, ~39%). Conclusions are scoped to the studied regime.
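The util/rank recipe and the A3 ablation can be illustrated with a small sketch. The abstract does not give the loss formulas, so everything below is an assumption: the oracle utility is taken to be the LM-loss reduction from running full rather than cheap at one layer (measured, like the off-policy label the authors describe, with every later layer forced to full), utility regression is mean squared error, and rank supervision is a pairwise logistic loss. All function names and the weights `w_util` and `w_rank` are hypothetical, and G3's predictive (JEPA) alignment term is left out. Setting both weights to zero corresponds to ablation A3.

```python
import math
from typing import Sequence


def oracle_utility(loss_if_cheap: float, loss_if_full: float) -> float:
    """Hypothetical oracle label: LM-loss reduction from running full at one layer.

    Off-policy assumption (as described in the abstract): both losses are
    measured as if every *subsequent* layer runs full, which gated execution
    does not do.
    """
    return loss_if_cheap - loss_if_full


def util_loss(scores: Sequence[float], targets: Sequence[float]) -> float:
    """Utility regression, assumed here to be mean squared error."""
    return sum((s - t) ** 2 for s, t in zip(scores, targets)) / len(scores)


def _softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def rank_loss(scores: Sequence[float], targets: Sequence[float]) -> float:
    """Pairwise rank supervision, assumed here to be a logistic loss.

    For every pair where token i has a higher oracle utility than token j,
    penalize the gate unless score i exceeds score j.
    """
    total, pairs = 0.0, 0
    for i in range(len(scores)):
        for j in range(len(scores)):
            if targets[i] > targets[j]:
                total += _softplus(-(scores[i] - scores[j]))
                pairs += 1
    return total / pairs if pairs else 0.0


def training_loss(lm: float, util: float, rank: float,
                  w_util: float = 1.0, w_rank: float = 1.0) -> float:
    """LM loss plus weighted util/rank terms; w_util = w_rank = 0 mimics ablation A3."""
    return lm + w_util * util + w_rank * rank
```

In this sketch the mismatch the authors point to sits in `oracle_utility`: its inputs come from an all-full continuation, while during gated training only a fraction of tokens in later controlled layers take the full path, so the labels describe a different execution policy from the one the gate actually controls.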