Introducing RadixMLP: Intra-batch deduplication for causal transformers
Blog post from Baseten
RadixMLP speeds up batch inference for causal transformer models by eliminating redundant computation across sequences that share a common prefix. MLPs, LayerNorms, linear projections, and embeddings all operate position-wise, so identical prefix positions produce identical activations. RadixMLP exploits this by dynamically mapping each batch onto a prefix trie, gathering shared segments into a compressed buffer, computing the position-wise layers once per unique prefix position, and scattering the results back.

In practice this delivers significant gains: 1.44–1.59x speedups on realistic reranking workloads, and up to 5x on synthetic benchmarks with longer shared prefixes. RadixMLP is integrated into Baseten Embeddings Inference.

Because the design is stateless and operates entirely within a single forward pass, RadixMLP provides cache-like benefits without the overhead of persistent state management, making it a strong fit for workloads with high prefix redundancy. It is also compatible with training and is released as open source under the MIT License.
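To make the trie-based deduplication concrete, here is a minimal sketch of how a batch might be mapped onto a prefix trie so that each unique prefix position gets a single row in a compressed buffer. This is an illustration of the idea, not RadixMLP's actual implementation; the function name `dedup_indices` and its interface are hypothetical.

```python
def dedup_indices(batch):
    """Map each (sequence, position) to a unique shared-prefix id via a trie.

    Returns the number of unique prefix positions and, for every slot in
    the flattened batch, an index into the compressed buffer.
    """
    trie = {}        # (parent_node, token) -> node id == row in compressed buffer
    scatter = []     # one entry per (sequence, position)
    for seq in batch:
        parent = -1  # virtual root of the trie
        for tok in seq:
            key = (parent, tok)
            if key not in trie:
                trie[key] = len(trie)  # first time this prefix position is seen
            parent = trie[key]
            scatter.append(parent)
    return len(trie), scatter

# Two sequences sharing the three-token prefix [7, 8, 9]:
batch = [[7, 8, 9, 1], [7, 8, 9, 2]]
n_unique, scatter = dedup_indices(batch)
print(n_unique)  # 5: eight total positions compress to five unique ones
print(scatter)   # [0, 1, 2, 3, 0, 1, 2, 4]
```

Position-wise layers would then run on only the `n_unique` compressed rows, and the full batch is recovered by indexing the compressed output with `scatter`. Attention, which mixes information across positions, still operates on the full sequences.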