Training Non-Differentiable Networks via Optimal Transport

arXiv cs.RO / 5/5/2026

💬 Opinion · Developer Stack & Infrastructure · Models & Research

Key Points

  • The paper introduces PolyStep, a gradient-free optimizer for training neural networks that include non-differentiable components (e.g., spiking neurons, quantized layers, discrete routing, and black-box simulators), avoiding surrogate-gradient bias.
  • PolyStep updates parameters using only forward passes: each step evaluates the loss at the vertices of a structured polytope in a compressed subspace, forms a cost matrix, and moves particles toward low-cost vertices via a barycentric projection that corresponds to the one-sided limit of a regularized optimal-transport problem (see the sketch after this list).
  • Experiments show PolyStep can train “genuinely non-differentiable” models where prior gradient-free methods collapse to near-random accuracy, reaching 93.4% test accuracy on hard-LIF spiking networks and beating other gradient-free baselines by more than 60 percentage points.
  • The method also scales and stays robust: it sustains above 92% clause satisfaction on MAX-SAT instances scaling from 100 to 1M variables, and its RL policy-search performance remains stable under integer and binary quantization.
  • The authors provide convergence guarantees: a rate of O(log T / sqrt(T)) toward conservative-stationary points on piecewise-smooth losses, upgraded to Clarke-stationarity for the main architectures, with rates matching the known zeroth-order query-complexity lower bounds that all forward-only methods inherit.
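
To make the update concrete, here is a minimal forward-only sketch in NumPy of one PolyStep-style step as described above. The function name, the cross-polytope vertex set, and all defaults (subspace dimension d, step radius, softmax temperature eps) are illustrative assumptions, not the authors' API:

```python
import numpy as np

def polystep_update(theta, loss_fn, d=16, radius=0.1, eps=0.5, rng=None):
    """One illustrative PolyStep-style step using only forward passes.

    Sketch under assumptions: the 2d vertices of a cross-polytope in a
    random d-dimensional subspace stand in for the paper's structured
    polytopes, and eps is a softmax temperature playing the role of the
    entropic-regularization strength. theta is a flat 1-D parameter vector.
    """
    rng = rng or np.random.default_rng()
    n = theta.size
    # Compressed subspace: random linear map from d latent dims to n params.
    P = rng.standard_normal((n, d)) / np.sqrt(d)
    # Structured polytope vertices: +/- unit directions scaled by the radius.
    V = radius * np.vstack([np.eye(d), -np.eye(d)])   # shape (2d, d)
    # Forward passes only: the cost (loss) at each candidate vertex.
    costs = np.array([loss_fn(theta + P @ v) for v in V])
    # Softmax-weighted assignment over the costs; subtracting the minimum
    # is the usual numerical-stability shift.
    w = np.exp(-(costs - costs.min()) / eps)
    w /= w.sum()
    # Barycentric projection: displace the parameters toward the weighted
    # average of the low-cost vertices (no Sinkhorn iterations needed).
    return theta + P @ (w @ V)
```

Here `loss_fn` can be any black-box scalar evaluation, for example a clause-violation count on a MAX-SAT instance or the negative return of an RL rollout, which is what makes the step applicable to genuinely non-differentiable models.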

Abstract

Neural networks increasingly embed non-differentiable components (spiking neurons, quantized layers, discrete routing, black-box simulators, etc.) where backpropagation is inapplicable and surrogate gradients introduce bias. We present PolyStep, a gradient-free optimizer that updates parameters using only forward passes. Each step evaluates the loss at structured polytope vertices in a compressed subspace, computes softmax-weighted assignments over the resulting cost matrix, and displaces particles toward low-cost vertices via barycentric projection. This update corresponds to the one-sided limit of a regularized optimal-transport problem, inheriting its geometric structure without Sinkhorn iterations. PolyStep trains genuinely non-differentiable models where existing gradient-free methods collapse to near-random accuracy. On hard-LIF spiking networks we reach 93.4% test accuracy, outperforming all gradient-free baselines by over 60 pp and closing to within 4.4 pp of a surrogate-gradient Adam ceiling. Across four additional non-differentiable architectures (int8 quantization, argmax attention, staircase activations, hard MoE routing) we lead every gradient-free competitor. On MAX-SAT scaling from 100 to 1M variables, we sustain above 92% clause satisfaction while evolution strategies drop 8–12 pp. On RL policy search, we match OpenAI-ES on classical control and retain performance under the integer and binary quantization that collapses gradient-based methods. We prove convergence to conservative-stationary points at rate O(log T / sqrt(T)) on piecewise-smooth losses, upgraded to Clarke-stationary on the headline architectures and extended to the piecewise-constant regime via a hitting-time bound. These rates match the known zeroth-order query-complexity lower bounds that all forward-only methods inherit. Code is available at https://github.com/anindex/polystep.
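
One plausible way to formalize the abstract's "one-sided limit of a regularized optimal-transport problem" is as a semi-relaxed entropic OT problem in which only the particle-side marginal is constrained; the notation below (cost matrix C, temperature ε) is our assumption, not taken from the paper:

```latex
% Hedged reconstruction of the softmax-weighted assignment and barycentric
% projection; the semi-relaxed entropic OT reading is an assumption.
\[
  \pi^\star(k \mid i)
    = \frac{\exp\!\big(-C_{ik}/\varepsilon\big)}
           {\sum_{j=1}^{K} \exp\!\big(-C_{ij}/\varepsilon\big)},
  \qquad
  x_i \;\leftarrow\; \sum_{k=1}^{K} \pi^\star(k \mid i)\, v_k ,
\]
% where C_{ik} is the cost of moving particle x_i to polytope vertex v_k
% (e.g. the loss evaluated at that vertex).
```

Under this reading, because only one marginal of the transport plan is constrained, the entropically regularized plan factorizes into independent softmax rows, so the barycentric projection on the right requires no Sinkhorn iterations, consistent with the abstract's claim.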