A Mechanistic Interpretation of Arithmetic Reasoning in Language Models using Causal Mediation Analysis
Mathematical reasoning in large language models (LMs) has garnered significant attention in recent work, but there is a limited understanding of how these models process and store information related to arithmetic tasks within their architecture. In order to improve our understanding of this aspect of language models, we present a mechanistic interpretation of Transformer-based LMs on arithmetic questions using a causal mediation analysis framework. By intervening on the activations of specific model components and measuring the resulting changes in predicted probabilities, we identify the subset of parameters responsible for specific predictions. This provides insights into how information related to arithmetic is processed by LMs. Our experimental results indicate that LMs process the input by transmitting the information relevant to the query from mid-sequence early layers to the final token using the attention mechanism. Then, this information is processed by a set of MLP modules, which generate result-related information that is incorporated into the residual stream.
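As a rough illustration of the intervention procedure described above, the following sketch patches activations in a toy two-layer network and records how the probability of a target output changes. It is not the models or the effect measure used in this work. The network, the random weights `W1` and `W2`, the inputs `x_orig` and `x_alt`, the `forward` helper, and the simple probability-difference measure are all assumptions made for illustration.

```python
import numpy as np

def softmax(z):
    z = z - z.max()  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def forward(x, W1, W2, patch=None):
    """Toy two-layer network (hypothetical stand-in for an LM).

    `patch` optionally maps a hidden-unit index to the activation value
    that should overwrite it before the output layer is applied.
    Returns the output distribution and the (possibly patched) hidden state.
    """
    h = np.tanh(W1 @ x)
    if patch is not None:
        for idx, value in patch.items():
            h[idx] = value
    return softmax(W2 @ h), h

rng = np.random.default_rng(0)
W1 = rng.normal(size=(8, 4))   # input -> hidden weights (random, illustrative)
W2 = rng.normal(size=(5, 8))   # hidden -> output weights (random, illustrative)
x_orig = rng.normal(size=4)    # stands in for the original query
x_alt = rng.normal(size=4)     # stands in for the query with altered operands

p_orig, h_orig = forward(x_orig, W1, W2)
p_alt, h_alt = forward(x_alt, W1, W2)

# Output index predicted under the altered input; we track its probability.
target = int(np.argmax(p_alt))

# Intervene on one hidden unit at a time: run the original input, but
# overwrite that unit with the activation it takes under the altered input.
effects = {}
for i in range(h_orig.size):
    p_patched, _ = forward(x_orig, W1, W2, patch={i: h_alt[i]})
    effects[i] = p_patched[target] - p_orig[target]

# Units ranked by how strongly they mediate the change in the prediction.
ranking = sorted(effects, key=lambda i: abs(effects[i]), reverse=True)
```

Units at the top of `ranking` are the ones whose activations, when swapped, shift the prediction the most in this toy setting. In a real Transformer the same idea would be applied to components such as attention or MLP outputs at specific layers and token positions.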
Introduction. Mathematical reasoning with Transformer-based models (Vaswani et al., 2017) is challenging because it requires an understanding of the quantities and the mathematical concepts involved. While large language models (LMs) have recently achieved impressive performance on a set of math-based tasks (Wei et al., 2022a; Chowdhery et al., 2022; OpenAI, 2023), their behavior has been shown to be inconsistent and context-dependent (Bubeck et al., 2023). Many recent works propose methods to improve the performance of large LMs on math benchmark datasets through enhanced pre-training (Spokoyny et al., 2022; Lewkowycz et al., 2022; Liu and Low, 2023) or specific prompting techniques (Wei et al., 2022b; Kojima et al., 2022; Yang et al., 2023, inter alia). However, there is a limited understanding of the inner workings of these models and of how they store and process information to correctly perform math-based tasks.
Discussion / Conclusion. We proposed the use of causal mediation analysis to mechanistically investigate how LMs process information related to arithmetic. Through controlled interventions on specific subsets of the model, we assessed the impact of these mediators on the model’s predictions. We posited that models produce predictions for arithmetic queries by conveying the math-relevant information from the mid-sequence early layers to the last token, where this information is then processed by late MLP modules. We carried out a causality-grounded experimental procedure on four different Transformer-based LMs, and we provided empirical evidence supporting our hypothesis. Furthermore, we showed that the information flow we observed in our experiments is specific to arithmetic queries when compared with two other tasks that do not involve arithmetic computation. Our findings suggest potential avenues for research into model pruning and more targeted training/fine-tuning by concentrating on specific model components associated with certain queries or computations.