Training Nonlinear Transformers for Chain-of-Thought Inference: A Theoretical Generalization Analysis
Chain-of-Thought (CoT) is an efficient prompting method that enables the reasoning ability of large language models by augmenting the query with multiple examples, each containing multiple intermediate steps. Despite its empirical success, the theoretical understanding of how to train a Transformer to achieve CoT ability remains less explored. This is primarily due to the technical challenges of analyzing nonconvex optimization on nonlinear attention models. To the best of our knowledge, this work provides the first theoretical study of training Transformers with nonlinear attention to obtain CoT generalization capability, so that the resulting model can perform inference on unseen tasks when the input is augmented with examples of the new task. We first quantify the training samples and iterations required to train a Transformer model toward CoT ability. We then prove that its CoT generalization succeeds on unseen tasks with distribution-shifted testing data. Moreover, we theoretically characterize the conditions under which CoT produces an accurate reasoning output even when the provided reasoning examples contain noise and are not always accurate.
Introduction. Transformer-based large-scale foundation models, such as GPT-3 (Brown et al., 2020), GPT-4 (OpenAI, 2023), LLaMa (Touvron et al., 2023a;b), and Sora (Liu et al., 2024), have demonstrated remarkable success across various tasks, including natural language processing (Brown et al., 2020; Touvron et al., 2023b), multimodal learning (OpenAI, 2023; Radford et al., 2021), and image/video generation (OpenAI, 2023; Liu et al., 2024). More surprisingly, large language models (LLMs) demonstrate reasoning ability through the so-called “Chain-of-Thought” (CoT) method (Wei et al., 2022). The objective is to let a pre-trained LLM generate $K$ reasoning steps for an input query $x_{query}$ without any fine-tuning. To achieve this, the input $x_{query}$ is augmented with $l$ examples $\{x_i, \{y_{i,j}\}_{j=1}^{K}\}_{i=1}^{l}$ of a given $K$-step reasoning task, where each $x_i$ is an input, $y_{i,j}$ is its $j$-th reasoning step, and $y_{i,K}$ is the final output.
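The prompt construction and step-by-step generation described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the paper's model: inputs and reasoning steps are plain string tokens rather than the vector patterns a Transformer would consume, and `toy_lookup_model` is a hypothetical stand-in for a trained model that simply copies the token that followed the most recent earlier occurrence of the current last token. The CoT loop feeds each generated step back into the prompt before producing the next one, so after $K$ iterations the last output plays the role of the final answer.

```python
from typing import Callable, List, Sequence, Tuple

# One in-context example: an input x_i and its K reasoning steps y_{i,1..K}.
Example = Tuple[str, Sequence[str]]


def build_cot_prompt(examples: Sequence[Example], x_query: str) -> List[str]:
    """Flatten the l examples and the query into a single prompt sequence.

    Each example contributes x_i, y_{i,1}, ..., y_{i,K}; x_query is appended
    last, and its K reasoning steps are left for the model to generate.
    """
    prompt: List[str] = []
    for x_i, steps in examples:
        prompt.append(x_i)
        prompt.extend(steps)
    prompt.append(x_query)
    return prompt


def cot_inference(
    model: Callable[[List[str]], str],
    examples: Sequence[Example],
    x_query: str,
    K: int,
) -> List[str]:
    """Generate K reasoning steps, appending each step to the prompt."""
    prompt = build_cot_prompt(examples, x_query)
    outputs: List[str] = []
    for _ in range(K):
        step = model(prompt)
        outputs.append(step)
        prompt.append(step)  # the new step becomes context for the next one
    return outputs


def toy_lookup_model(prompt: List[str]) -> str:
    """Hypothetical stand-in model: find the most recent earlier occurrence
    of the last token and return the token that followed it."""
    last = prompt[-1]
    for i in range(len(prompt) - 2, -1, -1):
        if prompt[i] == last:
            return prompt[i + 1]
    raise ValueError("no matching token in the context")


if __name__ == "__main__":
    # l = 2 examples of a K = 2 step task (tokens are illustrative only).
    demo = [("a", ["b", "c"]), ("d", ["e", "f"])]
    print(build_cot_prompt(demo, "d"))
    print(cot_inference(toy_lookup_model, demo, "d", K=2))
```

Running the script prints the flattened prompt `['a', 'b', 'c', 'd', 'e', 'f', 'd']` followed by the generated steps `['e', 'f']`. The toy model only retrieves patterns already present in the context; whether and when a trained nonlinear Transformer behaves this way on unseen tasks is exactly the question the paper's analysis addresses.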
Discussion / Conclusion. This paper theoretically analyzes the training dynamics of Transformers with nonlinear attention, together with the CoT generalization ability of the resulting model on new tasks with noisy and partially inaccurate context examples. We quantitatively characterize and compare the conditions required for the success of CoT and in-context learning (ICL). Although our analysis is based on a simplified Transformer model and on reasoning tasks that operate on patterns, this work deepens the theoretical understanding of the CoT mechanism. Future directions include designing efficient prompt-generation methods for CoT and analyzing LLM reasoning on more complicated data models.
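To make the CoT versus ICL comparison concrete, the sketch below contrasts the two prompt formats. It assumes, as is common in the literature, that an ICL prompt keeps only each example's input $x_i$ and final output $y_{i,K}$, whereas a CoT prompt keeps all $K$ intermediate steps; the function names are hypothetical and tokens are plain strings used for illustration only.

```python
from typing import List, Sequence, Tuple

# One in-context example: an input x_i and its K reasoning steps y_{i,1..K}.
Example = Tuple[str, Sequence[str]]


def cot_prompt(examples: Sequence[Example], x_query: str) -> List[str]:
    """CoT format: every example keeps all K intermediate steps."""
    seq: List[str] = []
    for x_i, steps in examples:
        seq.append(x_i)
        seq.extend(steps)
    seq.append(x_query)
    return seq


def icl_prompt(examples: Sequence[Example], x_query: str) -> List[str]:
    """ICL format (assumed here): each example keeps only x_i and y_{i,K}."""
    seq: List[str] = []
    for x_i, steps in examples:
        seq.extend([x_i, steps[-1]])
    seq.append(x_query)
    return seq


if __name__ == "__main__":
    demo = [("a", ["b", "c"]), ("d", ["e", "f"])]
    print(cot_prompt(demo, "q"))
    print(icl_prompt(demo, "q"))
```

With $l$ examples of a $K$-step task, the CoT prompt has $l(K+1)+1$ tokens and exposes every intermediate step, while the ICL prompt has $2l+1$ tokens and asks the model to map $x_{query}$ directly to the final output. This is one intuition for why the two methods can require different conditions to succeed, though the precise conditions are those derived in the paper.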