Jin Xu
Researcher in Machine Learning
Microsoft Research
I am a senior researcher in Efficient AI at Microsoft. My current research focuses on self-improving AI for GPU kernel optimization and the design of novel large language model architectures that improve efficiency on modern hardware accelerators. More broadly, I am interested in the theoretical foundations of deep learning, efficient model design, and the interplay among optimization methods, architectural choices, and hardware systems.
Previously, I completed my Ph.D. in the Department of Statistics at the University of Oxford, supervised by Yee Whye Teh and Tom Rainforth, as part of the OxCSML group. My doctoral research explored meta-learning, geometric deep learning, and statistical machine learning, with an emphasis on building scalable, principled and sample-efficient statistical models.
Interests
- AI Scientists
- LLM Pretraining
- Geometric Deep Learning
- Statistical Machine Learning
Education
Ph.D. in Statistical Machine Learning
University of Oxford
M.Sc. in Artificial Intelligence
University of Edinburgh
B.Sc. in Mathematics and Applied Mathematics
Fudan University
Publications
Revisiting Transformer Layer Parameterization Through Causal Energy Minimization
Jin Xu, Camille Couturier, Victor Rühle, Saravan Rajmohan, James Hensman
arXiv preprint arXiv:2605.07588
Transformer blocks typically combine multi-head attention (MHA) for token mixing with gated MLPs for token-wise feature transformation, yet many choices in their parameterization remain largely empirical. We introduce Causal Energy Minimization (CEM), a framework that recasts Transformer layers as optimization steps on conditional energy functions while explicitly accounting for layer parameterization. Extending prior energy-based interpretations of attention, CEM shows that weight-tied MHA can be derived as a gradient update on an interaction energy, and that a gated MLP with shared up/down projections can be viewed through an element-wise energy. This perspective identifies a design space for Transformer layers that includes within-layer weight sharing, diagonal-plus-low-rank interactions, lightweight preconditioners, and recursive updates. We evaluate CEM-derived layers in language-modeling experiments at the moderate hundred-million-parameter scale. Despite their constrained parameterizations, these layers train stably and can match corresponding Transformer baselines. Overall, our results suggest that CEM provides a useful lens for understanding Transformer layer parameterization, connecting Transformer architectures to energy-based models and motivating further exploration of energy-guided layer designs.
TANDEM: Bi-Level Data Mixture Optimization with Twin Networks
Jiaxing Wang, Deping Xiang, Jin Xu, Mingyang Yi, Guoqiang Gong, Zicheng Zhang, Haoran Li, Pengzhang Liu, Zhen Chen, Ke Zhang, Ju Fan, Qixia Jiang
Advances in Neural Information Processing Systems 38 (NeurIPS 2025)
The capabilities of large language models (LLMs) depend significantly on training data drawn from various domains. Optimizing domain-specific mixture ratios can be modeled as a bi-level optimization problem, which we simplify into a single-level penalized form and solve with twin networks: a proxy model trained on primary data and a dynamically updated reference model trained with additional data. Our proposed method, Twin Networks for bi-level DatA mixturE optiMization (TANDEM), measures data efficacy through the difference between the twin models and up-weights domains that benefit more from the additional data. Compared with prior approaches, TANDEM provides theoretical guarantees and broader applicability. Furthermore, our bi-level perspective suggests new settings in which to study domain reweighting, such as data-restricted scenarios and supervised fine-tuning, where optimized mixture ratios significantly improve performance. Extensive experiments validate TANDEM's effectiveness in all of these scenarios.
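To make the reweighting idea concrete, here is a minimal sketch assuming a multiplicative (exponentiated-gradient-style) update in which each domain's "benefit" is the proxy model's loss minus the reference model's loss on that domain. The function name `reweight_domains`, the learning rate and the loss values are hypothetical; this is not the exact TANDEM update.

```python
import math

def reweight_domains(weights, proxy_losses, reference_losses, lr=1.0):
    """Hypothetical multiplicative update: domains where the reference model
    (trained with additional data) improves most over the proxy get more weight."""
    gains = [p - r for p, r in zip(proxy_losses, reference_losses)]
    logits = [math.log(w) + lr * g for w, g in zip(weights, gains)]
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]  # renormalize to a mixture

weights = [1 / 3, 1 / 3, 1 / 3]
proxy = [2.0, 2.5, 3.0]        # illustrative per-domain losses of the proxy model
reference = [1.9, 2.0, 2.95]   # illustrative per-domain losses of the reference model
new_weights = reweight_domains(weights, proxy, reference)
print(new_weights)
```

The second domain shows the largest gap between the twin models, so its mixture weight grows the most.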
OdysseyBench: Evaluating LLM Agents on Long-Horizon Complex Office Application Workflows
Weixuan Wang, Dongge Han, Daniel Madrigal Diaz, Jin Xu, Victor Rühle, Saravan Rajmohan
arXiv preprint arXiv:2508.09124
Autonomous agents powered by large language models (LLMs) are increasingly deployed in real-world applications requiring complex, long-horizon workflows. However, existing benchmarks predominantly focus on atomic tasks that are self-contained and independent, failing to capture the long-term contextual dependencies and multi-interaction coordination required in realistic scenarios. To address this gap, we introduce OdysseyBench, a comprehensive benchmark for evaluating LLM agents on long-horizon workflows across diverse office applications including Word, Excel, PDF, Email, and Calendar. Our benchmark comprises two complementary splits: OdysseyBench+ with 300 tasks derived from real-world use cases, and OdysseyBench-Neo with 302 newly synthesized complex tasks. Each task requires the agent to identify essential information from long-horizon interaction histories and to perform multi-step reasoning across various applications. To enable scalable benchmark creation, we propose HomerAgents, a multi-agent framework that automates the generation of long-horizon workflow benchmarks through systematic environment exploration, task generation, and dialogue synthesis. Our extensive evaluation demonstrates that OdysseyBench effectively challenges state-of-the-art LLM agents, providing a more accurate assessment of their capabilities in complex, real-world contexts than existing atomic-task benchmarks.
Beyond Logits: Aligning Feature Dynamics for Effective Knowledge Distillation
Guoqiang Gong, Jiaxing Wang, Jin Xu, Deping Xiang, Zicheng Zhang, Leqi Shen, Yifeng Zhang, Junhua Shu, Zhaolong Xing, Zhen Chen, Pengzhang Liu, Ke Zhang
Proceedings of the 63rd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)
Knowledge distillation (KD) compresses large language models (LLMs), known as teacher models, into lightweight versions called student models, enabling efficient inference and downstream applications. However, prevailing approaches do so by focusing predominantly on matching the final output distributions of the student and teacher models. Drawing on the perspective that Transformers can be viewed as discretizations of ordinary differential equations (ODEs) at integer time steps (corresponding to layer indices), where intermediate features evolve across layers, we argue that effective KD requires aligning the entire feature dynamics of the teacher and student models, which we call feature dynamics distillation (FDD). This alignment involves matching both the feature trajectory and its first-order derivative, rather than only the final states. Our approach extends the original KD objective with two additional loss terms: layer-wise feature KD, which matches the discretized feature trajectory, and layer feature delta KD, which matches first-order changes in features across adjacent layers. Extensive experiments on various tasks validate the effectiveness of our distillation method.
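A minimal sketch of the three-term objective described above, assuming squared-error feature losses, a KL term on the output distributions, and teacher and student features already aligned to the same number of layers and the same width (a real setup may need a layer mapping and learned projections). The weights `alpha` and `beta` and all names are hypothetical.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def squared_error(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))

def fdd_loss(teacher_logits, student_logits, teacher_feats, student_feats,
             alpha=1.0, beta=1.0):
    """Sketch of a feature-dynamics distillation objective (assumed form)."""
    # Standard KD term on the final output distributions.
    kd = kl_divergence(softmax(teacher_logits), softmax(student_logits))
    # Layer-wise feature KD: match the discretized trajectory itself.
    traj = sum(squared_error(s, t) for s, t in zip(student_feats, teacher_feats))
    # Feature-delta KD: match first-order changes between adjacent layers.
    delta = 0.0
    for l in range(len(teacher_feats) - 1):
        dt = [b - a for a, b in zip(teacher_feats[l], teacher_feats[l + 1])]
        ds = [b - a for a, b in zip(student_feats[l], student_feats[l + 1])]
        delta += squared_error(ds, dt)
    return kd + alpha * traj + beta * delta

teacher_feats = [[0.0, 1.0], [1.0, 2.0], [3.0, 2.0]]
shifted_feats = [[x + 0.5 for x in f] for f in teacher_feats]  # same dynamics, offset states
logits = [1.0, 0.0]
print(fdd_loss(logits, logits, teacher_feats, shifted_feats))             # trajectory term only
print(fdd_loss(logits, logits, teacher_feats, shifted_feats, alpha=0.0))  # delta term only
```

The shifted student illustrates why the two feature terms differ: its states are off by a constant, so the trajectory term is positive, but its layer-to-layer changes match the teacher exactly, so the delta term is zero.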
Enhancing Reasoning Capabilities of Small Language Models with Blueprints and Prompt Template Search
Dongge Han, Menglin Xia, Daniel Madrigal, Samuel Kessler, Ankur Mallick, Xuchao Zhang, Mirian Del Carmen Hipolito Garcia, Jin Xu, Victor Rühle, Saravan Rajmohan
ICML 2025 Workshop on Tiny Titans: The Next Wave of On-Device Learning for Foundation Models (TTODLer-FM)
Small language models (SLMs) offer promising and efficient alternatives to large language models (LLMs). However, SLMs' limited capacity restricts their reasoning capabilities and makes them sensitive to prompt variations. To address these challenges, we propose a novel framework that enhances SLM reasoning capabilities through LLM-generated blueprints. The blueprints provide structured, high-level reasoning guides that help SLMs systematically tackle related problems. Furthermore, our framework integrates a prompt template search mechanism to mitigate the SLMs' sensitivity to prompt variations. Our framework demonstrates improved SLM performance across various tasks, including math (GSM8K), coding (MBPP), and logical reasoning (BBH). Our approach improves the reasoning capabilities of SLMs without increasing model size or requiring additional training, offering a lightweight and deployment-friendly solution for on-device or resource-constrained environments.
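One simple way to picture the template search is validation-based selection: fill each candidate template with the blueprint and a question, score the model on a small dev set, and keep the best template. The sketch below assumes exactly that. `toy_slm` is a hypothetical stand-in for a real small model, and the templates and dev set are illustrative only.

```python
def build_prompt(template, blueprint, question):
    """Fill a prompt template with a blueprint and a question."""
    return template.format(blueprint=blueprint, question=question)

def search_template(templates, blueprint, dev_set, run_slm):
    """Return the template with the highest dev-set accuracy (first wins ties)."""
    best_template, best_acc = None, -1.0
    for template in templates:
        correct = sum(run_slm(build_prompt(template, blueprint, q)) == answer
                      for q, answer in dev_set)
        acc = correct / len(dev_set)
        if acc > best_acc:
            best_template, best_acc = template, acc
    return best_template, best_acc

def toy_slm(prompt):
    """Hypothetical stand-in model: it only answers correctly when the
    blueprint ("Guide:") comes before the question, mimicking prompt sensitivity."""
    lines = prompt.splitlines()
    if lines[0].startswith("Guide:") and lines[-1].startswith("Q: "):
        left, right = lines[-1][3:].split("+")
        return str(int(left) + int(right))
    return "?"

templates = ["Q: {question}\nGuide: {blueprint}",
             "Guide: {blueprint}\nQ: {question}"]
blueprint = "Add the two numbers."
dev_set = [("2+3", "5"), ("10+4", "14")]
best, acc = search_template(templates, blueprint, dev_set, toy_slm)
print(repr(best), acc)
```

Because the search only needs model outputs, it requires no extra training. This matches the framework's goal of leaving the SLM itself unchanged.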
On Feature Learning in Structured State Space Models
Leena Chennuru Vankadara*, Jin Xu*, Moritz Haas, Volkan Cevher
NeurIPS 2024
This paper studies the scaling behavior of state-space models (SSMs) and structured variants such as Mamba, focusing on their capability to learn features in the infinite-width limit. We show that common scaling rules, such as the Maximal Update Parameterization, fail to support feature learning in SSMs, and that spectral scaling conditions that are effective for other architectures do not apply. A detailed forward and backward signal propagation analysis uncovers a scaling regime that enables non-trivial feature evolution in infinite-width SSMs, offering improved stability, generalization, and hyperparameter transfer.
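For readers less familiar with SSMs, the sketch below shows the basic structured (diagonal) linear SSM that such analyses start from. It computes the same input-output map in two ways, as a recurrence and as a causal convolution. The state dimension `N` plays the role of width. The 1/sqrt(N) scaling of `c` is a placeholder default, not the scaling regime identified in the paper, and Mamba's input-dependent (selective) parameters are omitted.

```python
import random

def ssm_recurrent(a, b, c, xs):
    """Diagonal linear SSM: h_t = a * h_{t-1} + b * x_t, y_t = <c, h_t>."""
    h = [0.0] * len(a)
    ys = []
    for x in xs:
        h = [ai * hi + bi * x for ai, hi, bi in zip(a, h, b)]
        ys.append(sum(ci * hi for ci, hi in zip(c, h)))
    return ys

def ssm_convolutional(a, b, c, xs):
    """The same map unrolled as a causal convolution, K_k = sum_n c_n a_n^k b_n."""
    T = len(xs)
    kernel = [sum(ci * ai ** k * bi for ai, bi, ci in zip(a, b, c)) for k in range(T)]
    return [sum(kernel[k] * xs[t - k] for k in range(t + 1)) for t in range(T)]

random.seed(0)
N = 16  # state dimension ("width")
a = [random.uniform(0.5, 0.95) for _ in range(N)]        # stable decay rates
b = [random.gauss(0, 1) for _ in range(N)]
c = [random.gauss(0, 1) / N ** 0.5 for _ in range(N)]   # placeholder width scaling
xs = [random.gauss(0, 1) for _ in range(20)]
y_rec = ssm_recurrent(a, b, c, xs)
y_conv = ssm_convolutional(a, b, c, xs)
print(max(abs(r - v) for r, v in zip(y_rec, y_conv)))
```

Width-dependent multipliers on `a`, `b`, `c` and on their learning rates are the knobs that infinite-width scaling rules prescribe. The paper's point is that the usual choices for these knobs do not yield feature learning in SSMs.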
μP²: Effective Sharpness Aware Minimization Requires Layerwise Perturbation Scaling
Moritz Haas, Jin Xu, Volkan Cevher, Leena Chennuru Vankadara
Advances in Neural Information Processing Systems
Sharpness Aware Minimization (SAM) enhances performance across various neural architectures and datasets. As models are continually scaled up to improve performance, a rigorous understanding of SAM's scaling behaviour is paramount. To this end, we study the infinite-width limit of neural networks trained with SAM, using the Tensor Programs framework. Our findings reveal that the dynamics of standard SAM effectively reduce to applying SAM solely in the last layer in wide neural networks, even with optimal hyperparameters. In contrast, we identify a stable parameterization with layerwise perturbation scaling, which we call Maximal Update and Perturbation Parameterization (μP²), that ensures all layers are both feature learning and effectively perturbed in the limit. Through experiments with MLPs, ResNets and Vision Transformers, we empirically demonstrate that μP² is the first parameterization to achieve hyperparameter transfer of the joint optimum of learning rate and perturbation radius across model scales.
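A minimal sketch of a SAM step with layerwise perturbation scaling, on a toy quadratic loss. Standard SAM perturbs all weights by `rho * g / ||g||` and then takes a gradient step using the gradient at the perturbed point. The per-layer `multipliers` here are hypothetical knobs: setting them all to 1 recovers standard SAM, while μP² prescribes specific width-dependent scalings that are not reproduced here.

```python
import math

def loss(params, curvature):
    """Toy quadratic loss L(w) = 0.5 * sum_l sum_i h_{l,i} * w_{l,i}^2."""
    return 0.5 * sum(h * w * w for name in params
                     for h, w in zip(curvature[name], params[name]))

def grad_quadratic(params, curvature):
    """Gradient of the toy loss, layer by layer."""
    return {name: [h * w for h, w in zip(curvature[name], ws)]
            for name, ws in params.items()}

def sam_perturbation(grads, rho, multipliers):
    """eps_l = rho * m_l * g_l / ||g||, with ||g|| the global gradient norm."""
    norm = math.sqrt(sum(g * g for gs in grads.values() for g in gs)) + 1e-12
    return {name: [rho * multipliers[name] * g / norm for g in gs]
            for name, gs in grads.items()}

def sam_step(params, curvature, lr, rho, multipliers):
    grads = grad_quadratic(params, curvature)
    eps = sam_perturbation(grads, rho, multipliers)
    perturbed = {n: [w + e for w, e in zip(params[n], eps[n])] for n in params}
    sam_grads = grad_quadratic(perturbed, curvature)  # gradient at the perturbed point
    return {n: [w - lr * g for w, g in zip(params[n], sam_grads[n])] for n in params}

params = {"hidden": [1.0, -2.0], "readout": [0.5]}
curvature = {"hidden": [1.0, 3.0], "readout": [2.0]}
ones = {"hidden": 1.0, "readout": 1.0}  # all-ones multipliers: standard SAM
eps = sam_perturbation(grad_quadratic(params, curvature), 0.05, ones)
eps_norm = math.sqrt(sum(e * e for es in eps.values() for e in es))
new_params = sam_step(params, curvature, lr=0.1, rho=0.05, multipliers=ones)
print(loss(params, curvature), loss(new_params, curvature))
```

Per-layer multipliers make it possible to control how strongly each layer is perturbed. This is the degree of freedom that the paper argues must be scaled per layer so that every layer stays effectively perturbed as width grows.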
Deep Stochastic Processes via Functional Markov Transition Operators
Jin Xu, Emilien Dupont, Kaspar Märtens, Tom Rainforth, Yee Whye Teh
Advances in Neural Information Processing Systems
We introduce Markov Neural Processes (MNPs), a new class of stochastic processes (SPs) constructed by stacking sequences of neurally parameterised Markov transition operators in function space. We prove that these operators preserve exchangeability and consistency, so MNPs add flexibility to Neural Processes without compromising these properties. Experiments demonstrate clear advantages of MNPs over baselines across a range of tasks.
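The sketch below shows why a transition acting on function values can keep both properties. It uses a deliberately simplified, hypothetical transition in which each new value depends only on its own `(x, f(x))` pair and its own noise draw, stacked on a random linear base function. Actual MNP operators are more expressive than this; the parameters `theta = (u, v, sigma)` are illustrative.

```python
import math
import random

def transition(xs, fs, noise, theta):
    """One hypothetical Markov transition acting on function values.

    Each new value depends only on (x, f(x)) and its own noise draw, so
    permuting the inputs permutes the outputs (exchangeability) and dropping
    a point leaves the others unchanged (consistency).
    """
    u, v, sigma = theta
    return [f + math.tanh(u * x + v * f) + sigma * e
            for x, f, e in zip(xs, fs, noise)]

def sample_chain(xs, thetas, rng):
    """Start from a random linear function f_0(x) = w*x + b (correlated across
    inputs) and stack several transitions on top of it."""
    w, b = rng.gauss(0, 1), rng.gauss(0, 1)
    fs = [w * x + b for x in xs]
    for theta in thetas:
        noise = [rng.gauss(0, 1) for _ in xs]
        fs = transition(xs, fs, noise, theta)
    return fs

xs = [0.0, 0.5, -1.0]
fs = [0.1, 0.2, 0.3]
noise = [0.3, -0.1, 0.7]
theta = (1.0, 0.5, 0.1)
out = transition(xs, fs, noise, theta)
perm = [2, 0, 1]
out_perm = transition([xs[i] for i in perm], [fs[i] for i in perm],
                      [noise[i] for i in perm], theta)
print(out_perm == [out[i] for i in perm])                        # True: exchangeability
print(transition(xs[:2], fs[:2], noise[:2], theta) == out[:2])   # True: consistency
sample = sample_chain(xs, [theta, (0.5, -0.5, 0.05)], random.Random(0))
```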
Group Equivariant Subsampling
Jin Xu, Hyunjik Kim, Tom Rainforth, Yee Whye Teh
Advances in Neural Information Processing Systems
We introduce translation- and group-equivariant subsampling/upsampling layers to construct exactly equivariant CNNs and group-equivariant autoencoders. Learned representations generalize to unseen positions and orientations and show improved data efficiency and object-centric decomposition.
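A 1D, translation-only sketch of the subsampling idea. Instead of always keeping indices `0, stride, 2*stride, ...`, the sampling offset is chosen from the input itself, so a shift of the input moves the sampling grid along with it. Here the offset is the position of the largest absolute value modulo the stride. This choice is an assumption made for illustration. The paper's layers handle general groups, multi-channel feature maps and the matching upsampling.

```python
import random

def equivariant_subsample(x, stride=2):
    """Translation-equivariant subsampling of a 1D circular signal (sketch).

    Returns the subsampled signal and the input-dependent offset.
    """
    assert len(x) % stride == 0
    peak = max(range(len(x)), key=lambda i: abs(x[i]))
    offset = peak % stride
    return x[offset::stride], offset

def roll(x, s):
    """Circular shift: roll(x, s)[n] == x[(n - s) % len(x)]."""
    s %= len(x)
    return x[-s:] + x[:-s] if s else list(x)

random.seed(1)
x = [random.gauss(0, 1) for _ in range(12)]
out, offset = equivariant_subsample(x, stride=3)
# Shifting the input by s shifts the output by (offset + s) // stride.
for s in range(len(x)):
    shifted_out, _ = equivariant_subsample(roll(x, s), stride=3)
    assert shifted_out == roll(out, (offset + s) // 3)
print(offset, len(out))
```

Fixed-grid subsampling (`x[0::stride]`) has no such guarantee. For shifts that are not multiples of the stride, it keeps a different set of samples. Choosing the grid equivariantly is what makes the whole network exactly equivariant.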
MetaFun: Meta-Learning with Iterative Functional Updates
Jin Xu, Jean-Francois Ton, Hyunjik Kim, Adam R Kosiorek, Yee Whye Teh
International Conference on Machine Learning
We develop a functional encoder–decoder approach to supervised meta-learning, in which labelled data are encoded into infinite-dimensional functional representations via learned iterative updates. The final representation then conditions a decoder that makes predictions. Our approach is the first encoder–decoder meta-learner to achieve state-of-the-art performance on miniImageNet and tieredImageNet.
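A minimal sketch of an iterative functional update in this spirit. The representation r(x) starts at zero. At each iteration, a local update is computed at every context point, and those updates are spread to all inputs through a kernel: r(x) ← r(x) − lr · Σ_i k(x, x_i) · u(r(x_i), y_i). Two simplifications are assumed here. The representation is scalar, and the learned local-update network is replaced by the hand-crafted u(r, y) = r − y, which turns the procedure into kernel functional gradient descent on squared error. The decoder is omitted.

```python
import math

def rbf(x1, x2, lengthscale=0.5):
    return math.exp(-((x1 - x2) ** 2) / (2 * lengthscale ** 2))

def local_update(r_i, y_i):
    """Hand-crafted stand-in for a learned local update: d/dr of 0.5 * (r - y)^2."""
    return r_i - y_i

def functional_updates(context_x, context_y, query_x, steps=50, lr=0.3):
    """Refine r(x), tracked only at the context and query inputs we need."""
    points = list(context_x) + list(query_x)
    r = [0.0] * len(points)  # r^(0)(x) = 0
    n = len(context_x)
    for _ in range(steps):
        u = [local_update(r[i], y) for i, y in enumerate(context_y)]
        # Kernel-smoothed update, applied to every point (context and query alike).
        r = [r_p - lr * sum(rbf(x_p, xi) * ui for xi, ui in zip(context_x, u))
             for r_p, x_p in zip(r, points)]
    return r[:n], r[n:]

context_x = [-1.0, 0.0, 1.0]
context_y = [math.sin(x) for x in context_x]
query_x = [-0.5, 0.5]
r_context, r_query = functional_updates(context_x, context_y, query_x)
print(max(abs(r - y) for r, y in zip(r_context, context_y)), r_query)
```

With a learned `u` and a learned kernel or attention, the same loop becomes far more flexible. The final representation would then be passed to a decoder rather than read off directly.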
Controllable Probabilistic Semantic Image Inpainting
Jin Xu, Yee Whye Teh
arXiv preprint
We develop a method for user-controllable semantic image inpainting using a deep generative model combining an encoder for arbitrary observed pixels, disentangled latent variables, and a bidirectional PixelCNN. Our method generates plausible, coherent inpaintings matching user-specified semantics while remaining consistent with observations.