Equipping language models for systematic reasoning
Abstract
With increases in the scale of pretraining corpora and model parameter budgets, recent large language models (LMs) have become capable enough to be applied directly to textual reasoning tasks. However, these models continue to exhibit surprisingly acute shortcomings in their ability to carry out procedural reasoning and consistently apply rules (Dziri et al., 2023; Qiu et al., 2023; Tang et al., 2023, i.a.). Our goal is to address these shortcomings: to enable models to solve problems expressed in natural language by reasoning through them with a systematic inference process. To equip LMs with such a mechanism while retaining their ability to handle natural language input, we investigate hybridizing LMs with classical proof search methods for automated reasoning. We approach this hybridization from two angles: 1) specializing LMs to make consistent elementary inferences, then building those steps into a structured search procedure to carry out reasoning in language, and 2) augmenting LMs with a learnable proof search module, then training the resulting models to use this module when making predictions.

We investigate approach 1) by first developing a model for the consistent generation of single-step textual inferences. We find that distilling symbolic reasoning patterns into the model, by incorporating semi-synthetic data into the training distribution, leads to more reliable generation of valid inferences. We then consider how to improve the consistency of multi-step textual deduction. To this end, we incorporate our single-step textual inference models into a structured search procedure guided by a learned goal-directed heuristic. We experimentally demonstrate that our system, SCSearch, produces multi-step deductions with much higher internal consistency than the baseline autoregressive approach.

We explore approach 2) by directly fusing a soft proof search layer onto the architecture of a transformer language model. The resulting proof search augmented language models (PSALMs) are entirely differentiable and support training with labels, proof traces, reference rules, or any combination thereof. We conduct experiments to determine the best way to train PSALMs for deductive reasoning, and find that guiding rule representations to unify consistently achieves out-of-distribution generalization where vanilla transformers fail.
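To illustrate what incorporating semi-synthetic data into the training distribution might look like, the Python sketch below instantiates symbolic inference templates with sampled lexical fillers and interleaves the results with natural training examples. The template skeletons, the lexicon structure, and the synth_ratio parameter are illustrative assumptions, not the data pipeline actually used in the thesis.

import random

# Symbolic inference skeletons rendered in natural language; instantiating
# them with lexical fillers yields "semi-synthetic" training examples.
TEMPLATES = [
    # modus ponens: (premise 1, premise 2, valid conclusion)
    ("If {x} is a {a}, then {x} is a {b}.", "{x} is a {a}.", "{x} is a {b}."),
    # universal instantiation
    ("All {a}s are {b}s.", "{x} is a {a}.", "{x} is a {b}."),
]

def make_semi_synthetic(lexicon):
    """Instantiate one randomly chosen template with sampled lexical items."""
    p1, p2, c = random.choice(TEMPLATES)
    a, b = random.sample(lexicon["nouns"], 2)
    slots = {"x": random.choice(lexicon["names"]), "a": a, "b": b}
    return {"premises": [p1.format(**slots), p2.format(**slots)],
            "conclusion": c.format(**slots)}

def mixed_stream(natural_data, lexicon, synth_ratio=0.3):
    """Yield training examples with a fixed fraction of semi-synthetic data."""
    while True:
        if random.random() < synth_ratio:
            yield make_semi_synthetic(lexicon)
        else:
            yield random.choice(natural_data)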
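The structured search procedure over single-step inferences can be pictured as a best-first search in which a learned goal-directed heuristic ranks candidate deductions by how promising they appear relative to the goal. The sketch below is a minimal illustration of that control flow, assuming a step_model that maps two premises to a conclusion and a heuristic scorer; it is not the SCSearch implementation itself, and a real system would use an entailment check rather than string equality against the goal.

import heapq
import itertools

def best_first_deduce(premises, goal, step_model, heuristic, max_expansions=50):
    """Goal-directed best-first search over single-step textual inferences.

    step_model(p1, p2) -> conclusion string or None (assumed inference LM)
    heuristic(statement, goal) -> float; lower means closer to the goal
    """
    tie = itertools.count()  # tie-breaker so the heap never compares lists
    frontier = [(0.0, next(tie), list(premises), [])]  # (score, tie, facts, trace)
    for _ in range(max_expansions):
        if not frontier:
            break
        _, _, facts, trace = heapq.heappop(frontier)
        # Propose a new inference from every pair of known statements.
        for p1, p2 in itertools.combinations(facts, 2):
            conclusion = step_model(p1, p2)
            if conclusion is None or conclusion in facts:
                continue
            step = ((p1, p2), conclusion)
            if conclusion == goal:  # stand-in for an entailment check
                return trace + [step]
            score = heuristic(conclusion, goal)
            heapq.heappush(frontier, (score, next(tie),
                                      facts + [conclusion], trace + [step]))
    return None  # no proof found within the expansion budget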
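Finally, one way to realize a fully differentiable proof search layer is through soft unification: scoring every known fact against learned rule representations by embedding similarity, then emitting attention-weighted rule conclusions as new facts. The PyTorch module below is a sketch of that idea under those assumptions; the actual PSALM architecture may differ, and the module name, parameters, and dimensions here are illustrative.

import torch
import torch.nn as nn

class SoftProofStep(nn.Module):
    """One differentiable proof-search step (illustrative sketch).

    Each stored rule is a (premise, conclusion) pair of embeddings. A fact
    "unifies" softly with a rule in proportion to its similarity to the
    rule's premise; the step emits attention-weighted rule conclusions as
    newly derived facts. Stacking such layers approximates multi-step
    deduction end to end.
    """

    def __init__(self, dim, num_rules):
        super().__init__()
        self.rule_premises = nn.Parameter(torch.randn(num_rules, dim))
        self.rule_conclusions = nn.Parameter(torch.randn(num_rules, dim))

    def forward(self, facts):  # facts: (batch, n_facts, dim)
        # Soft unification: similarity of every fact to every rule premise.
        scores = facts @ self.rule_premises.t()      # (batch, n_facts, num_rules)
        weights = torch.softmax(scores, dim=-1)
        # Derived facts: attention-weighted mixture of rule conclusions.
        derived = weights @ self.rule_conclusions    # (batch, n_facts, dim)
        return torch.cat([facts, derived], dim=1)    # grow the fact set

Because every operation here is soft attention, gradients from a downstream loss flow into both the rule embeddings and the encoder producing the fact embeddings, which is what makes it possible in principle to train from labels alone, from proof traces, or from reference rules.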