Yihe Dong
See my Google Scholar profile for a more up-to-date list.
Publications
Attention Retrieves, MLP Memorizes: Disentangling Trainable Components in the Transformer
Yihe Dong, Lorenzo Noci, Mikhail Khodak, Mufan Li
arXiv preprint
arXiv / abstract
The transformer architecture is central to the success of modern Large Language Models (LLMs), in part due to its surprising ability to perform a wide range of tasks - including mathematical reasoning, memorization, and retrieval - using only gradient-based learning on next-token prediction. While the core component of a transformer is the self-attention mechanism, we question how much, and which aspects, of the performance gains can be attributed to it. To this end, we compare standard transformers to variants in which either the MLP layers or the attention weights are frozen at initialization. Surprisingly, we find that attention with frozen key and query weights is not only able to form induction heads, but can also perform competitively on language modeling. We formalize this by proving a new expressivity result for transformer models with frozen key and query weights. To further isolate the contribution of attention, we design MixiT, an architecture with entirely random attention scores, with provably stable signal propagation that overcomes prior depth-wise scaling challenges in random transformers. We use the successes and failures of MixiT to understand the role each transformer component plays, such as attention being largely responsible for in-context reasoning, and MLPs, in collaboration with attention, being responsible for knowledge storage. Our results suggest that the transformer architecture has a built-in inductive bias towards forming specialized circuits, as it forms them even without learnable attention weights.
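As a rough illustration of the ablation described above (a minimal sketch, not the paper's code), the block below freezes the key and query projections at their random initialization while leaving the value, output, and MLP weights trainable.

```python
# Minimal sketch of the ablation described above (not the paper's code): a pre-LN
# transformer block whose key and query projections are frozen at initialization,
# while the value, output, and MLP weights remain trainable.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class FrozenQKBlock(nn.Module):
    def __init__(self, d_model: int, n_heads: int, d_ff: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads, self.d_head = n_heads, d_model // n_heads
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)
        self.w_o = nn.Linear(d_model, d_model, bias=False)
        self.ln1, self.ln2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(),
                                 nn.Linear(d_ff, d_model))
        # The ablation: key/query weights stay at their random initialization.
        for p in list(self.w_q.parameters()) + list(self.w_k.parameters()):
            p.requires_grad_(False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # x: (batch, seq, d_model)
        h = self.ln1(x)
        b, t, d = h.shape
        split = lambda z: z.view(b, t, self.n_heads, self.d_head).transpose(1, 2)
        q, k, v = split(self.w_q(h)), split(self.w_k(h)), split(self.w_v(h))
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)
        causal = torch.triu(torch.ones(t, t, dtype=torch.bool, device=x.device), 1)
        scores = scores.masked_fill(causal, float("-inf"))
        out = F.softmax(scores, dim=-1) @ v                      # (b, heads, t, d_head)
        x = x + self.w_o(out.transpose(1, 2).reshape(b, t, d))   # trainable value/output path
        return x + self.mlp(self.ln2(x))                         # trainable MLP

x = torch.randn(2, 16, 64)
print(FrozenQKBlock(64, 4, 256)(x).shape)  # torch.Size([2, 16, 64])
```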
Metadata Conditioning Accelerates Language Model Pre-training
Tianyu Gao, Alexander Wettig, Luxi He, Yihe Dong, Sadhika Malladi, Danqi Chen
ICML 2025
arXiv / abstract
The vast diversity of styles, domains, and quality levels present in language model pre-training corpora is essential in developing general model capabilities, but efficiently learning and deploying the correct behaviors exemplified in each of these heterogeneous data sources is challenging. To address this, we propose a new method, termed Metadata Conditioning then Cooldown (MeCo), to incorporate additional learning cues during pre-training. MeCo first provides metadata (e.g., URLs like www.wikipedia.org) alongside the text during training and later uses a cooldown phase with only the standard text, thereby enabling the model to function normally even without metadata. MeCo significantly accelerates pre-training across different model scales (600M to 8B parameters) and training sources (C4, RefinedWeb, and DCLM). For instance, a 1.6B language model trained with MeCo matches the downstream task performance of standard pre-training while using 33% less data. Additionally, MeCo enables us to steer language models by conditioning the inference prompt on either real or fabricated metadata that encodes the desired properties of the output: for example, prepending wikipedia.org to reduce harmful generations or factquizmaster.com (fabricated) to improve common knowledge task performance. We also demonstrate that MeCo is compatible with different types of metadata, such as model-generated topics. MeCo is remarkably simple, adds no computational overhead, and demonstrates promise in producing more capable and steerable language models.
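As a toy illustration of the idea (the separator string, metadata placement, and cooldown fraction below are assumptions, not the paper's specification):

```python
# Toy sketch of metadata conditioning: prepend source metadata (e.g. a URL) to each
# document during the main pre-training phase, then drop it during the cooldown
# phase so the model also behaves normally on plain text. The separator and the
# 10% cooldown fraction are illustrative assumptions.
def format_example(text: str, url: str, step: int, total_steps: int,
                   cooldown_frac: float = 0.1) -> str:
    in_cooldown = step >= int((1.0 - cooldown_frac) * total_steps)
    return text if in_cooldown else f"{url}\n\n{text}"

# The same mechanism can steer generation at inference time by conditioning the
# prompt on real or fabricated metadata:
prompt = format_example("The tallest mountain on Earth is", "wikipedia.org",
                        step=0, total_steps=1)
```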
Cache Me If You Can: How Many KVs Do You Need for Effective Long-Context LMs?
Adithya Bhaskar, Alexander Wettig, Tianyu Gao, Yihe Dong, Danqi Chen
arXiv preprint
arXiv / abstract
Language models handle increasingly long contexts for tasks such as book summarization, but this leads to growing memory costs for the key-value (KV) cache. Many prior works have proposed ways of discarding KVs from memory, but their approaches are tailored to favorable settings, obscuring caveats like high peak memory and performance degradation, and a fair comparison between methods is difficult. In this paper, we propose the KV footprint as a unified metric, which accounts for both the amount of KV entries stored and their lifespan in memory. We evaluate methods based on the smallest footprint they attain while preserving performance in both long-context understanding and generation, with context lengths of up to 128K tokens. This metric reveals the high peak memory of prior KV eviction methods. One class of methods -- post-fill eviction -- has a high footprint due to being incompatible with eviction during pre-filling. We adapt these methods to be able to evict KVs during pre-filling, achieving substantially lower KV footprints. We then turn to recency eviction methods, wherein we propose PruLong, an end-to-end optimization method for learning which attention heads need to retain the full KV cache and which do not. PruLong saves memory while preserving long-context performance, achieving 12% smaller KV footprint than prior methods while retaining performance in challenging recall tasks. Our paper clarifies the complex tangle of long-context inference methods and paves the way for future development to minimize the KV footprint.
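As a rough sketch of why eviction during pre-filling matters under a footprint-style metric (the accounting below counts (KV entry, step) pairs kept in memory and is an illustrative simplification, not the paper's exact definition):

```python
# Illustrative footprint-style accounting: total number of (KV entry, step) pairs
# resident in memory over a forward pass. A simplification of the idea, not the
# paper's exact metric.
def kv_footprint_full(prompt_len: int, gen_len: int) -> int:
    # Full cache: each token's KV stays resident from its creation step to the end.
    total = prompt_len + gen_len
    return sum(total - t for t in range(total))

def kv_footprint_window(prompt_len: int, gen_len: int, window: int) -> int:
    # Recency eviction (including during pre-filling): each KV entry lives at most
    # `window` steps before being discarded.
    total = prompt_len + gen_len
    return sum(min(window, total - t) for t in range(total))

print(kv_footprint_full(128_000, 512))           # full cache
print(kv_footprint_window(128_000, 512, 4_096))  # recency window, far smaller
```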
AdaRank: Disagreement Based Module Rank Prediction for Low-rank Adaptation
Yihe Dong
arXiv preprint
arXiv / abstract
With the rise of language and multimodal models of ever-increasing size, pretraining a general-purpose foundational model and adapting it to downstream tasks has become common practice. To this end, adaptation efficiency can be a critical bottleneck given the large model sizes; hence, efficient finetuning methods such as LoRA have become prevalent. However, LoRA is typically applied with the same rank across all model layers, despite mounting evidence from the transfer learning literature that during finetuning, later layers diverge more from pretrained weights. Inspired by the theory and observations around feature learning and module criticality, we develop AdaRank, a simple model-disagreement-based technique to predict the rank of a given module relative to the other modules. Empirically, AdaRank generalizes notably better on unseen data than using uniform ranks with the same number of parameters. Compared to prior work, AdaRank has the unique advantage of leaving the pretraining and adaptation stages completely intact: it needs no additional objectives or regularizers, which can hinder adaptation accuracy and performance.
SLM: End-to-end Feature Selection via Sparse Learnable Masks
Yihe Dong, Sercan O. Arik
arXiv preprint
arXiv / abstract
Feature selection has been widely used to alleviate compute requirements during training, aid model interpretability, and improve model generalizability. We propose SLM -- Sparse Learnable Masks -- a canonical approach for end-to-end feature selection that scales well with respect to both the feature dimension and the number of samples. At the heart of SLM lies a simple but effective learnable sparse mask that learns which features to select, together with a novel objective, derived from a quadratic relaxation of mutual information from first principles, that provably maximizes the mutual information (MI) between the selected features and the labels. In addition, we derive a scaling mechanism that allows SLM to precisely control the number of features selected, through a novel use of sparsemax. This allows for more effective learning, as demonstrated in ablation studies. Empirically, SLM achieves state-of-the-art results against a variety of competitive baselines on eight benchmark datasets, often by a significant margin, especially on those with real-world challenges such as class imbalance.
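For reference, a minimal NumPy version of the sparsemax projection that SLM builds on (the standard sparsemax of Martins and Astudillo; SLM's scaling mechanism and mutual-information objective are not reproduced here):

```python
# Minimal NumPy sparsemax (Martins & Astudillo, 2016): a projection onto the
# probability simplex that produces exact zeros, which is what yields sparse
# feature-selection masks. SLM's full scaling mechanism is not shown.
import numpy as np

def sparsemax(z: np.ndarray) -> np.ndarray:
    """Project a score vector onto the simplex, zeroing out small coordinates."""
    z_sorted = np.sort(z)[::-1]
    cumsum = np.cumsum(z_sorted)
    k = np.arange(1, len(z) + 1)
    support = 1 + k * z_sorted > cumsum           # coordinates kept in the support
    k_z = k[support][-1]
    tau = (cumsum[support][-1] - 1.0) / k_z       # threshold
    return np.maximum(z - tau, 0.0)

scores = np.array([2.0, 1.1, 0.2, -0.5, 1.9])
mask = sparsemax(scores)                          # only the largest scores survive
print(mask, "selected:", np.flatnonzero(mask))    # selects features 0 and 4
```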
LANISTR: Multimodal Learning from Structured and Unstructured Data
Sayna Ebrahimi, Sercan O. Arik, Yihe Dong, Tomas Pfister
arXiv preprint
arXiv / abstract
Multimodal large-scale pretraining has shown impressive performance for unstructured data such as language and images. However, a prevalent real-world scenario involves structured data types, such as tabular and time-series data, alongside unstructured data, and this setting has been understudied. To bridge this gap, we propose LANISTR, an attention-based framework to learn from LANguage, Image, and STRuctured data. The core of LANISTR's methodology is rooted in masking-based training applied across both unimodal and multimodal levels. In particular, we introduce a new similarity-based multimodal masking loss that enables it to learn cross-modal relations from large-scale multimodal data with missing modalities. On two real-world datasets, MIMIC-IV (from healthcare) and Amazon Product Review (from retail), LANISTR demonstrates remarkable improvements of 6.6% (in AUROC) and 14% (in accuracy) when fine-tuned with 0.1% and 0.01% of labeled data, respectively, compared to state-of-the-art alternatives. Notably, these improvements are observed even when a very high ratio of samples (35.7% and 99.8%, respectively) does not contain all modalities, underlining the robustness of LANISTR to the practical missing-modality challenge.
COSTAR: Improved Temporal Counterfactual Estimation
Chuizheng Meng, Yihe Dong, Sercan Ö. Arık, Yan Liu, Tomas Pfister
arXiv preprint
arXiv / abstract
Estimation of temporal counterfactual outcomes from observed history is crucial for decision-making in many domains such as healthcare and e-commerce, particularly when randomized controlled trials (RCTs) suffer from high cost or impracticality. For real-world datasets, modeling time-dependent confounders is challenging due to complex dynamics, long-range dependencies, and the fact that both past treatments and covariates affect future outcomes. In this paper, we introduce Counterfactual Self-Supervised Transformer (COSTAR), a novel approach that integrates self-supervised learning for improved historical representations. We propose a component-wise contrastive loss tailored for temporal treatment outcome observations and explain its effectiveness from the view of unsupervised domain adaptation. COSTAR yields superior performance in estimation accuracy and generalization to out-of-distribution data compared to existing models, as validated by empirical results on both synthetic and real-world datasets.
Koopman Neural Forecaster for Time Series with Temporal Distribution Shifts
Rui Wang, Yihe Dong, Sercan Ö. Arik, Rose Yu
ICLR 2023
arXiv / abstract
Temporal distributional shifts, with underlying dynamics changing over time, frequently occur in real-world time series and pose a fundamental challenge for deep neural networks (DNNs). In this paper, we propose a novel deep sequence model based on Koopman theory for time series forecasting: the Koopman Neural Forecaster (KNF), which leverages DNNs to learn the linear Koopman space and the coefficients of chosen measurement functions. KNF imposes appropriate inductive biases for improved robustness against distributional shifts, employing both a global operator to learn shared characteristics and a local operator to capture changing dynamics, as well as a specially designed feedback loop to continuously update the learned operators over time for rapidly varying behaviors. We demonstrate that KNF achieves superior performance compared to the alternatives on multiple time series datasets that suffer from distribution shifts.
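As a schematic of the underlying Koopman idea only (KNF learns the measurement functions and both global and local operators with neural networks; the fixed measurement functions and plain least-squares fit below are simplifying assumptions):

```python
# Schematic of the Koopman idea (not KNF itself): lift a time series with fixed
# measurement functions, then fit a single linear operator K by least squares so
# that psi(x_{t+1}) ~ K^T psi(x_t). KNF instead learns the measurement functions
# and both global and local operators, plus a feedback loop, with DNNs.
import numpy as np

def lift(x: np.ndarray) -> np.ndarray:
    # Hand-picked measurement functions: identity, square, sine.
    return np.stack([x, x**2, np.sin(x)], axis=-1)            # (T, 3)

t = np.linspace(0, 20, 400)
series = np.sin(t) + 0.05 * np.random.randn(t.size)

psi = lift(series)                                            # (T, 3)
K, *_ = np.linalg.lstsq(psi[:-1], psi[1:], rcond=None)        # (3, 3) linear operator

# One-step forecast in the lifted space; the first coordinate is the state itself.
pred = psi[:-1] @ K
print("one-step MSE:", np.mean((pred[:, 0] - series[1:]) ** 2))
```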
Attention is Not All You Need: Pure Attention Loses Rank Doubly Exponentially with Depth
Yihe Dong, Jean-Baptiste Cordonnier, Andreas Loukas
ICML 2021 (Oral)
arXiv / abstract
Attention-based architectures have become ubiquitous in machine learning, yet our understanding of the reasons for their effectiveness remains limited. This work proposes a new way to understand self-attention networks: we show that their output can be decomposed into a sum of smaller terms, each involving the operation of a sequence of attention heads across layers. Using this decomposition, we prove that self-attention possesses a strong inductive bias towards "token uniformity". Specifically, without skip connections or multi-layer perceptrons (MLPs), the output converges doubly exponentially to a rank-1 matrix. On the other hand, skip connections and MLPs stop the output from degenerating. Our experiments verify the identified convergence phenomena on different variants of standard transformer architectures.
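Schematically, the rank-collapse result has the following shape (constants are omitted and depend on the attention weights, width, and number of heads; see the paper for the precise statement), where res(X) denotes the residual of X from its closest rank-1 matrix of the form $\mathbf{1}x^\top$:

```latex
% Schematic form only; constants omitted (see the paper for the exact theorem).
\[
  \bigl\|\operatorname{res}(X_L)\bigr\|
  \;\le\;
  C^{\,3^L}\,\bigl\|\operatorname{res}(X_0)\bigr\|^{3^L},
\]
% so for a depth-L attention-only network (no skip connections or MLPs), the
% distance to a rank-1, token-uniform matrix decays doubly exponentially in
% depth whenever C\,\|\operatorname{res}(X_0)\| < 1.
```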
COPT: Coordinated Optimal Transport on Graphs
Yihe Dong, Will Sawin
NeurIPS 2020
arXiv / abstract
We introduce COPT, a novel distance metric between graphs defined via an optimization routine that computes a coordinated pair of optimal transport maps simultaneously. This gives an unsupervised way to learn general-purpose graph representations, applicable to both graph sketching and graph comparison. COPT involves simultaneously optimizing dual transport plans, one between the vertices of two graphs, and another between graph signal probability distributions. We show theoretically that our method preserves important global structural information on graphs, in particular spectral information, and analyze connections to existing studies. Empirically, COPT outperforms state-of-the-art methods in graph classification on both synthetic and real datasets.
CoinPress: Practical Private Mean and Covariance Estimation
Sourav Biswas, Yihe Dong, Gautam Kamath, Jonathan Ullman
NeurIPS 2020
arXiv / abstract
We present simple differentially private estimators for the mean and covariance of multivariate sub-Gaussian data that are accurate at small sample sizes. We demonstrate the effectiveness of our algorithms both theoretically and empirically using synthetic and real-world datasets -- showing that their asymptotic error rates match the state-of-the-art theoretical bounds, and that they concretely outperform all previous methods. Specifically, previous estimators either have weak empirical accuracy at small sample sizes, perform poorly for multivariate data, or require the user to provide strong a priori estimates for the parameters.
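A simplified one-step sketch in the spirit of this approach (the actual algorithm iterates several such steps while shrinking the confidence ball; the radius and noise calibration below are illustrative):

```python
# Simplified one-step sketch in the spirit of CoinPress (the real algorithm
# iteratively shrinks a confidence ball over several such steps): clip samples to
# a ball around a coarse center, average, and add Gaussian noise calibrated to
# the clipped sensitivity. Radius and privacy budget here are illustrative.
import numpy as np

def clipped_gaussian_mean(x: np.ndarray, center: np.ndarray, radius: float,
                          rho: float) -> np.ndarray:
    """One rho-zCDP private mean step on data x of shape (n, d)."""
    n, d = x.shape
    diff = x - center
    norms = np.linalg.norm(diff, axis=1, keepdims=True)
    clipped = center + diff * np.minimum(1.0, radius / np.maximum(norms, 1e-12))
    sensitivity = 2.0 * radius / n                 # replacing one clipped point
    sigma = sensitivity / np.sqrt(2.0 * rho)       # Gaussian mechanism under zCDP
    return clipped.mean(axis=0) + np.random.normal(0.0, sigma, size=d)

x = np.random.randn(2000, 5) + 3.0
print(clipped_gaussian_mean(x, center=np.zeros(5), radius=10.0, rho=0.5))
```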
HNHN: Hypergraph Networks with Hyperedge Neurons
Yihe Dong, Will Sawin, Yoshua Bengio
Graph Representations and Beyond Workshop at ICML 2020
arXiv / abstract
Hypergraphs provide a natural representation for many real-world datasets. We propose a novel framework, HNHN, for hypergraph representation learning. HNHN is a hypergraph convolution network with nonlinear activation functions applied to both hypernodes and hyperedges, combined with a normalization scheme that can flexibly adjust the importance of high-cardinality hyperedges and high-degree vertices depending on the dataset. We demonstrate improved performance of HNHN in both classification accuracy and speed on real-world datasets when compared to state-of-the-art methods.
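A stripped-down sketch of an HNHN-style layer (the paper's tunable degree-based normalization is simplified to plain averaging here):

```python
# Stripped-down sketch of an HNHN-style layer: nonlinearities are applied both
# when aggregating nodes into hyperedges and hyperedges back into nodes. The
# paper's adjustable normalization exponents are simplified to plain averaging.
import numpy as np

def hnhn_layer(X: np.ndarray, H: np.ndarray, W_e: np.ndarray, W_v: np.ndarray):
    """X: (n_nodes, d) node features, H: (n_nodes, n_edges) incidence matrix."""
    d_e = H.sum(axis=0, keepdims=True).T + 1e-12       # hyperedge cardinalities
    d_v = H.sum(axis=1, keepdims=True) + 1e-12         # node degrees
    relu = lambda z: np.maximum(z, 0.0)
    E = relu((H.T @ X) / d_e @ W_e)                    # node -> hyperedge, nonlinear
    Xn = relu((H @ E) / d_v @ W_v)                     # hyperedge -> node, nonlinear
    return Xn, E

n, m, d = 6, 3, 4
H = (np.random.rand(n, m) < 0.5).astype(float)
X = np.random.randn(n, d)
Xn, E = hnhn_layer(X, H, np.random.randn(d, d), np.random.randn(d, d))
print(Xn.shape, E.shape)                               # (6, 4) (3, 4)
```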
Learning Sublinear-Time Indexing for Nearest Neighbor Search
Yihe Dong, Piotr Indyk, Ilya Razenshteyn, Tal Wagner
ICLR 2020
arXiv / abstract
Space partitions of $\mathbb{R}^d$ underlie a vast and important class of fast nearest neighbor search (NNS) algorithms. Inspired by recent theoretical work on NNS for general metric spaces, we develop a new framework for building space partitions that reduces the problem to balanced graph partitioning followed by supervised classification. We instantiate this general approach with the KaHIP graph partitioner and neural networks, respectively, to obtain a new partitioning procedure called Neural Locality-Sensitive Hashing (Neural LSH). On several standard benchmarks for NNS, our experiments show that the partitions obtained by Neural LSH consistently outperform partitions found by quantization-based and tree-based methods as well as classic, data-oblivious LSH.
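A hedged sketch of the two-stage recipe, with scikit-learn's k-means standing in for the KaHIP balanced graph partitioner and a small MLP as the learned index; the actual pipeline (k-NN graph construction, balanced partitioning, multi-probe querying) is more involved:

```python
# Hedged sketch of the two-stage recipe: (1) partition the dataset (k-means here
# stands in for balanced graph partitioning with KaHIP), (2) learn a classifier
# that maps points to partitions, then probe only the most probable partitions.
import numpy as np
from sklearn.cluster import KMeans
from sklearn.neural_network import MLPClassifier

rng = np.random.default_rng(0)
data = rng.standard_normal((5000, 32)).astype(np.float32)

parts = KMeans(n_clusters=16, n_init=5, random_state=0).fit_predict(data)
clf = MLPClassifier(hidden_layer_sizes=(128,), max_iter=300).fit(data, parts)

def query(q: np.ndarray, n_probes: int = 3) -> int:
    """Search only the n_probes most probable partitions for the query."""
    probs = clf.predict_proba(q[None])[0]
    candidates = np.flatnonzero(np.isin(parts, np.argsort(probs)[-n_probes:]))
    dists = np.linalg.norm(data[candidates] - q, axis=1)
    return int(candidates[np.argmin(dists)])            # index of approximate NN

print(query(rng.standard_normal(32).astype(np.float32)))
```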
Scalable Nearest Neighbor Search for Optimal Transport
Arturs Backurs, Yihe Dong, Piotr Indyk, Ilya Razenshteyn, Tal Wagner
ICML 2020
arXiv / abstract
The Optimal Transport (a.k.a. Wasserstein) distance is an increasingly popular similarity measure for rich data domains, such as images or text documents. This raises the need for fast nearest neighbor search algorithms under this distance, which poses a substantial computational bottleneck on massive datasets. In this work we introduce Flowtree, a fast and accurate approximation algorithm for the Wasserstein-1 distance. We formally analyze its approximation factor and running time. We perform an extensive experimental evaluation of nearest neighbor search algorithms in the $W_1$ distance on real-world datasets. Our results show that, compared to the previous state of the art, Flowtree achieves up to 7.4 times faster running time.
Quantum Entropy Scoring for Fast Robust Mean Estimation and Improved Outlier Detection
Yihe Dong, Samuel B. Hopkins, Jerry Li
NeurIPS 2019 (Spotlight)
arXiv / abstract
We study two problems in high-dimensional robust statistics: robust mean estimation and outlier detection. In robust mean estimation the goal is to estimate the mean $\mu$ of a distribution on $\mathbb{R}^d$ given $n$ independent samples, an $\varepsilon$-fraction of which have been corrupted by a malicious adversary. In outlier detection the goal is to assign an outlier score to each element of a data set such that elements more likely to be outliers are assigned higher scores. Our algorithms for both problems are based on a new outlier scoring method we call QUE-scoring based on quantum entropy regularization. For robust mean estimation, this yields the first algorithm with optimal error rates and nearly-linear running time $\widetilde{O}(nd)$ in all parameters, improving on the previous fastest running time $\widetilde{O}(\min(nd/\varepsilon^6, nd^2))$. For outlier detection, we evaluate the performance of QUE-scoring via extensive experiments on synthetic and real data, and demonstrate that it often performs better than previously proposed algorithms.
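Schematically, QUE-style scores weight each centered point by a density matrix proportional to a matrix exponential of the empirical covariance, emphasizing directions of unusually large variance (the normalization and the choice of alpha below are illustrative, not the paper's exact prescription):

```python
# Schematic QUE-style outlier scores (normalization and alpha are illustrative):
# weight centered points by a density matrix exp(alpha*(Sigma - I))/tr(...), which
# emphasizes directions in which the data has unusually large variance.
import numpy as np
from scipy.linalg import expm

def que_scores(x: np.ndarray, alpha: float = 4.0) -> np.ndarray:
    centered = x - x.mean(axis=0)
    sigma = centered.T @ centered / len(x)                 # empirical covariance
    u = expm(alpha * (sigma - np.eye(sigma.shape[0])))     # quantum-entropy weighting
    u /= np.trace(u)
    return np.einsum("ij,jk,ik->i", centered, u, centered) # x_i^T U x_i

# Toy data: inliers plus a small cluster of outliers along one direction.
rng = np.random.default_rng(0)
x = np.vstack([rng.standard_normal((950, 20)),
               rng.standard_normal((50, 20)) + 5 * np.eye(20)[0]])
scores = que_scores(x)
print(scores[:5], scores[-5:])                             # outliers score higher
```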
SANNS: Scaling Up Secure Approximate k-Nearest Neighbors Search
Hao Chen, Ilaria Chillotti, Yihe Dong, Oxana Poburinnaya, Ilya Razenshteyn, M. Sadegh Riazi
USENIX Security 2020
arXiv / abstract
The $k$-Nearest Neighbor Search ($k$-NNS) is the backbone of several cloud-based services such as recommender systems, face recognition, and database search on text and images. In these services, the client sends the query to the cloud server and receives the response, in which case the query and response are revealed to the service provider. Such data disclosures are unacceptable in several scenarios due to the sensitivity of data and/or privacy laws. In this paper, we introduce SANNS, a system for secure $k$-NNS that keeps the client's query and the search result confidential. SANNS comprises two protocols: an optimized linear scan and a protocol based on a novel sublinear-time clustering-based algorithm. We prove the security of both protocols in the standard semi-honest model. The protocols are built upon several state-of-the-art cryptographic primitives such as lattice-based additively homomorphic encryption, distributed oblivious RAM, and garbled circuits. We provide several contributions to each of these primitives which are applicable to other secure computation tasks. Both of our protocols rely on a new circuit for approximate top-$k$ selection from $n$ numbers that is built from $O(n + k^2)$ comparators. We have implemented our proposed system and performed extensive experiments on four datasets in two different computation environments, demonstrating 18-31× faster response times compared to optimally implemented protocols from prior work. Moreover, SANNS is the first work that scales to a database of 10 million entries, pushing the limit by more than two orders of magnitude.
A Study of Performance of Optimal Transport
Yihe Dong, Yu Gao, Richard Peng, Ilya Razenshteyn, Saurabh Sawlani
Optimal Transport in Machine Learning Workshop at NeurIPS 2019
arXiv / abstract
We investigate the problem of efficiently computing optimal transport (OT) distances, which is equivalent to the node-capacitated minimum cost maximum flow problem in a bipartite graph. We compare runtimes in computing OT distances on data from several domains, such as synthetic data of geometric shapes, embeddings of tokens in documents, and pixels in images. We show that in practice, combinatorial methods such as network simplex and augmenting path based algorithms can consistently outperform numerical matrix-scaling based methods such as Sinkhorn and Greenkhorn, even in low accuracy regimes, with up to orders of magnitude speedups. Lastly, we present a new combinatorial algorithm that improves upon the classical Kuhn-Munkres algorithm.
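A tiny, self-contained comparison in the spirit of this study: exact OT for uniform weights over equally many points via SciPy's Hungarian-type assignment solver, versus a plain entropic Sinkhorn approximation (the regularization strength and iteration count are arbitrary choices):

```python
# Tiny comparison in the spirit of the study above: exact OT for uniform weights
# over equally many points (an assignment problem, solved combinatorially) versus
# a plain Sinkhorn approximation with an arbitrarily chosen regularization.
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist

rng = np.random.default_rng(0)
a_pts, b_pts = rng.random((300, 2)), rng.random((300, 2))
cost = cdist(a_pts, b_pts)                       # pairwise transport costs

# Exact: optimal assignment equals OT between uniform distributions of equal size.
rows, cols = linear_sum_assignment(cost)
exact = cost[rows, cols].mean()

def sinkhorn(cost: np.ndarray, reg: float = 0.05, iters: int = 500) -> float:
    """Entropically regularized OT cost via plain Sinkhorn iterations."""
    n, m = cost.shape
    K = np.exp(-cost / reg)
    a, b = np.ones(n) / n, np.ones(m) / m
    u, v = np.ones(n), np.ones(m)
    for _ in range(iters):
        u = a / (K @ v)
        v = b / (K.T @ u)
    plan = u[:, None] * K * v[None, :]
    return float((plan * cost).sum())

print("exact:", exact, "sinkhorn:", sinkhorn(cost))
```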
NLP and Large-Scale Information Retrieval on Mathematical Texts
Yihe Dong
Springer Lecture Notes in Computer Science, 2018
Springer / abstract
We develop methods for applying natural language processing techniques to mathematical texts, enabling large-scale information retrieval in mathematical literature. Our approach addresses the unique challenges of processing mathematical notation and semantics.
NLP-Based Detection of Mathematics Subject Classification
Yihe Dong
Springer Lecture Notes in Computer Science, 2018
Springer / abstract
We present a machine learning approach for automatically detecting Mathematics Subject Classification (MSC) codes from mathematical texts using natural language processing techniques. Our method achieves high accuracy in classifying mathematical papers into appropriate subject categories.