Scaling Expert Language Models with Unsupervised Domain Discovery
Large language models are typically trained densely: all parameters are updated with respect to all inputs. This requires synchronization of billions of parameters across thousands of GPUs. We introduce a simple but effective method to asynchronously train large, sparse language models on arbitrary text corpora. Our method clusters a corpus into sets of related documents, trains a separate expert language model on each cluster, and combines them in a sparse ensemble for inference. This approach generalizes embarrassingly parallel training by automatically discovering the domains for each expert, and eliminates nearly all the communication overhead of existing sparse language models. Our technique outperforms dense baselines on multiple corpora and few-shot tasks, and our analysis shows that specializing experts to meaningful clusters is key to these gains. Performance also improves with the number of experts and size of training data, suggesting this is a highly efficient and accessible approach to training large language models.
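The abstract describes the first stage as clustering a corpus into sets of related documents, each of which then becomes the training set for one expert. The text here does not say which clustering algorithm or document representation is used, so the sketch below assumes each document has already been embedded as a fixed-length vector (for example, a tf-idf vector) and uses plain k-means as a stand-in. The helper names `kmeans` and `shard_corpus` are hypothetical and are not the authors' code.

```python
import random
from typing import Dict, List, Sequence, Tuple

Vector = List[float]


def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def kmeans(points: List[Vector], k: int, iters: int = 20,
           seed: int = 0) -> Tuple[List[Vector], List[int]]:
    """Plain (unbalanced) k-means: returns centroids and a cluster id per point."""
    rng = random.Random(seed)
    centroids = [list(p) for p in rng.sample(points, k)]  # random initial centroids
    assign = [0] * len(points)
    for _ in range(iters):
        # Assignment step: each document goes to its nearest centroid.
        assign = [min(range(k), key=lambda c: sq_dist(p, centroids[c]))
                  for p in points]
        # Update step: move each centroid to the mean of its members.
        for c in range(k):
            members = [p for p, a in zip(points, assign) if a == c]
            if members:  # keep the old centroid if a cluster empties out
                centroids[c] = [sum(col) / len(members) for col in zip(*members)]
    return centroids, assign


def shard_corpus(docs: List[str], embeddings: List[Vector], k: int,
                 seed: int = 0) -> Tuple[List[Vector], Dict[int, List[str]]]:
    """Split documents into k shards; each shard would train one expert LM."""
    centroids, assign = kmeans(embeddings, k, seed=seed)
    shards: Dict[int, List[str]] = {c: [] for c in range(k)}
    for doc, c in zip(docs, assign):
        shards[c].append(doc)
    return centroids, shards
```

Because each shard is fixed once clustering finishes, the k experts can be trained with no communication between them; the centroids are kept so that they can be reused later to route inputs to experts.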
Introduction. Language models (LMs) are trained on up to trillions of tokens of text (Hoffmann et al., 2022; Touvron et al., 2023). This improves performance on many tasks, but also incurs an extreme cost: thousands of GPUs need to be active simultaneously to update all parameters at each step (Zhang et al., 2022; Chowdhery et al., 2022). Branch-Train-Merge (BTM; Li et al. 2022) alleviates this cost by dividing the total compute among a collection of smaller expert language models (ELMs), each independently trained on a distinct subset (or domain) of the training corpus and ensembled
Discussion / Conclusion. We introduce C-BTM (Cluster-Branch-Train-Merge), a new technique to efficiently train sparse LMs. C-BTM splits a corpus into k clusters, trains an expert LM on each cluster, and creates a sparse ensemble during inference. We observe that the optimal number of clusters for C-BTM increases with the amount of data, and using more clusters also allows us to aggressively parallelize training to efficiently scale to massive datasets. Future work could investigate C-BTM in multitask or multilingual settings, the usefulness of multiple iterations of C-BTM on a corpus (perhaps with hierarchical clustering), or the possibility of combining metadata- and cluster-based routing to scale to many heterogeneous datasets in parallel.
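The final step, a sparse ensemble at inference time, can be illustrated with a small sketch. The text above does not give the weighting rule, so this example assumes that each expert is scored by a softmax over the negative squared distance between a context embedding and that expert's cluster centroid (scaled by a hypothetical `temperature`), that only the `top_k` highest-scoring experts are kept and renormalized, and that their next-token distributions are mixed at the probability level. `expert_weights` and `sparse_ensemble` are illustrative names, and each expert is modeled as a callable returning a token-to-probability dictionary.

```python
import math
from typing import Callable, Dict, List, Sequence

Dist = Dict[str, float]  # next-token distribution: token -> probability


def sq_dist(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def expert_weights(context_emb: Sequence[float], centroids: List[List[float]],
                   top_k: int = 2, temperature: float = 1.0) -> Dict[int, float]:
    """Softmax over negative centroid distances, restricted to the top_k experts."""
    scores = [-sq_dist(context_emb, c) / temperature for c in centroids]
    keep = sorted(range(len(centroids)), key=lambda i: scores[i], reverse=True)[:top_k]
    m = max(scores[i] for i in keep)  # subtract the max for numerical stability
    exp = {i: math.exp(scores[i] - m) for i in keep}
    z = sum(exp.values())
    return {i: e / z for i, e in exp.items()}  # weights sum to 1 over kept experts


def sparse_ensemble(context: str, context_emb: Sequence[float],
                    centroids: List[List[float]],
                    experts: Sequence[Callable[[str], Dist]],
                    top_k: int = 2, temperature: float = 1.0) -> Dist:
    """Mix next-token distributions from only the selected experts."""
    weights = expert_weights(context_emb, centroids, top_k, temperature)
    mixed: Dist = {}
    for i, w in weights.items():  # experts outside the top_k are never called
        for tok, p in experts[i](context).items():
            mixed[tok] = mixed.get(tok, 0.0) + w * p
    return mixed
```

Keeping only `top_k` experts is what makes the ensemble sparse: inference cost grows with `top_k` rather than with the total number of clusters, which is consistent with the observation above that more clusters can be used as data grows.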