Implicit Chain of Thought Reasoning via Knowledge Distillation

Paper · arXiv 2311.01460 · Published November 2, 2023
Chain-of-Thought and Reasoning Methods

To augment language models (LMs) with the ability to reason, researchers usually prompt or finetune them to produce chain-of-thought reasoning steps before producing the final answer. However, although people use natural language to reason effectively, it may be that LMs could reason more effectively with some intermediate computation that is not in natural language. In this work, we explore an alternative reasoning approach: instead of explicitly producing the chain-of-thought reasoning steps, we use the language model’s internal hidden states to perform implicit reasoning. The implicit reasoning steps are distilled from a teacher model trained on explicit chain-of-thought reasoning, and instead of doing reasoning “horizontally” by producing intermediate words one-by-one, we distill it such that the reasoning happens “vertically” among the hidden states in different layers. We conduct experiments on a multi-digit multiplication task and a grade school math problem dataset and find that this approach enables solving tasks previously not solvable without explicit chain-of-thought, at a speed comparable to no chain-of-thought.
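The shift from "horizontal" to "vertical" reasoning can be pictured as turning a teacher's grid of hidden states (one per layer and per intermediate chain-of-thought token) into one target vector per layer, which a model then has to reproduce from the input alone. The sketch below is a minimal illustration under stated assumptions: the evenly spaced "diagonal" selection rule, the mean-squared-error objective, and the names `select_diagonal` and `vertical_target_loss` are hypothetical choices for exposition, not the paper's exact code.

```python
from typing import List

Vector = List[float]


def select_diagonal(teacher_states: List[List[Vector]]) -> List[Vector]:
    """Pick one teacher hidden state per layer (assumed diagonal rule).

    teacher_states[l][t] is a hypothetical teacher hidden vector at layer l
    and intermediate CoT token position t. Positions are spread evenly from
    the first to the last CoT token as the layer index grows (rounded to the
    nearest integer index).
    """
    num_layers = len(teacher_states)
    num_steps = len(teacher_states[0]) if num_layers else 0
    if num_layers == 0 or num_steps == 0:
        raise ValueError("need at least one layer and one CoT position")
    selected = []
    for layer in range(num_layers):
        if num_layers == 1:
            position = 0
        else:
            position = round(layer * (num_steps - 1) / (num_layers - 1))
        selected.append(teacher_states[layer][position])
    return selected


def mse(a: Vector, b: Vector) -> float:
    """Mean squared error between two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def vertical_target_loss(predicted_states: List[Vector],
                         teacher_states: List[List[Vector]]) -> float:
    """Average per-layer MSE between predicted states and diagonal targets."""
    targets = select_diagonal(teacher_states)
    return sum(mse(p, z) for p, z in zip(predicted_states, targets)) / len(targets)


# Toy example: 3 layers x 5 CoT positions, each state is a 1-d vector.
teacher = [[[float(10 * layer + t)] for t in range(5)] for layer in range(3)]
print(select_diagonal(teacher))                                  # [[0.0], [12.0], [24.0]]
print(vertical_target_loss([[0.0], [12.0], [24.0]], teacher))    # 0.0
```

Each layer receives a target drawn from a progressively later reasoning step, so moving up the layer stack plays the role that moving along the generated tokens plays in explicit chain of thought.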

Introduction. Large language models have demonstrated significant capabilities in tasks that demand both language understanding and reasoning, such as multi-hop question answering (Yang et al., 2018; Yao et al., 2023b) and solving math problems (Hendrycks et al., 2021; Cobbe et al., 2021; Welleck et al., 2022; Wei et al., 2022b; Kojima et al., 2022; Chen et al., 2022; Yue et al., 2023; Chern et al., 2023). To elicit their reasoning abilities, a prevalent paradigm has been the chain-of-thought reasoning approach (Nye et al., 2021; Wei et al., 2022b; Kojima et al., 2022). Under this paradigm, models are trained or prompted to articulate intermediate steps before producing the final answer. Although this approach aligns with human problem-solving strategies, it might not fully leverage the computational potential of these language models. Consider the transformer architecture (Vaswani et al., 2017), which can carry out computation both “horizontally”, by generating words in sequence, and “vertically”, by processing them through its many layers of internal hidden states.
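One rough way to compare the two directions is to count the longest chain of dependent layer applications before the answer is produced. The back-of-envelope model below is an assumption made for illustration, not an analysis from the paper: it supposes that each generated token's input depends on the previous token's top-layer output, so every explicit chain-of-thought token adds a full pass through all layers, whereas purely vertical reasoning is bounded by the layer count. The function name `sequential_depth` is hypothetical.

```python
def sequential_depth(num_layers: int, num_cot_tokens: int) -> int:
    """Longest chain of dependent layer applications before the answer.

    Assumed model: the answer needs one full pass of num_layers layers, and
    each explicit CoT token generated beforehand adds one more full pass,
    because the next token's input depends on the previous token's output.
    """
    if num_layers < 1 or num_cot_tokens < 0:
        raise ValueError("need num_layers >= 1 and num_cot_tokens >= 0")
    return num_layers * (num_cot_tokens + 1)


print(sequential_depth(12, 0))   # 12: vertical only, no CoT tokens
print(sequential_depth(12, 20))  # 252: 20 horizontal CoT steps
```

Under this model, explicit chain of thought buys depth by spending generation steps, while implicit reasoning has to fit its intermediate computation into the fixed number of layers, which is also why its speed stays close to answering without chain of thought.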

Discussion / Conclusion. In this work, we proposed the concept of implicit chain of thought reasoning for transformer-based language models, where reasoning is performed “vertically” among the transformer hidden states, instead of being performed “horizontally” in the form of generating intermediate tokens. This concept potentially enables the model to break away from the human-like reasoning process and develop its own internal reasoning process. To operationalize this concept, we proposed a three-step approach: mind-reading the teacher, thought emulation, and coupling and optimization. The high-level idea is to distill the knowledge of a teacher trained for horizontal reasoning into a student and an emulator trained for vertical reasoning.
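The three steps can be summarized as a training schedule over the teacher, student, and emulator. The sketch below records one plausible reading of that schedule and should be checked against the paper: the assumption is that the student first learns to answer while reading the teacher's hidden states, the emulator then learns to predict those states from the input alone, and finally the two are coupled and trained end to end on the answer. The `Phase` record, the `SCHEDULE` contents, and `implicit_cot_answer` are hypothetical illustrations, not the authors' implementation.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass(frozen=True)
class Phase:
    name: str
    trains: Tuple[str, ...]  # components whose parameters are updated (teacher stays frozen)
    inputs: str              # what the trained component conditions on (assumed)
    objective: str           # loss used in this phase (assumed)


SCHEDULE = [
    Phase("mind-reading the teacher", ("student",),
          "input + teacher hidden states", "final-answer likelihood"),
    Phase("thought emulation", ("emulator",),
          "input only", "match selected teacher hidden states"),
    Phase("coupling and optimization", ("emulator", "student"),
          "input only; emulator output feeds the student", "final-answer likelihood"),
]


def implicit_cot_answer(x: Any, emulator: Callable, student: Callable) -> Any:
    """Inference after coupling: no intermediate CoT tokens are generated."""
    thoughts = emulator(x)       # per-layer "thoughts" predicted from the input
    return student(x, thoughts)  # answer produced by reading those thoughts


for phase in SCHEDULE:
    print(f"{phase.name}: train {phase.trains} on {phase.objective}")

# Toy stand-ins for the two trained components.
print(implicit_cot_answer(3, lambda x: [2 * x], lambda x, z: x + z[0]))  # 9
```

Keeping the schedule as data makes the key design choice visible: the teacher is only a source of targets, and at inference time the emulator replaces the teacher's explicit reasoning tokens entirely.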