By Kris Georgiev, Sam Park, Roy Rinberg, Shivam Garg, Andrew Ilyas, Aleksander Madry, and Seth Neel

Introduction

Machine learning (ML) models are trained on data to identify patterns and make predictions. Typically, when a model is successfully trained on a dataset, the knowledge it gains from that data becomes deeply embedded in its parameters, and can even be extracted given access to the model. However, a model developer may realize after the fact that training on a subset of the data was problematic, and want to remove the influence of those data points from the parameters of the trained model.

Consider the setting where a language model is trained on text data, and you want to modify the model to “unlearn” all the text from Harry Potter (perhaps for copyright reasons)—you’re tasked with producing a new model that never “read” any Harry Potter.

At this point, you may have two obvious first ideas:

  1. You can retrain your model on the full dataset, now excluding Harry Potter. However, as models grow in size and training time, this can become prohibitively expensive (both in terms of time and $$).

  2. You may also try to corrupt your original model so that it performs poorly on Harry Potter; while this could make it forget Harry Potter, in practice the model will also likely “forget” lots of English, since much of the text in Harry Potter is not specific to the wizarding world.

Is there something cheaper (than re-training) and more principled we can do to unlearn? This challenge has motivated a recent line of work$^1$ on (approximate) machine unlearning, where the goal is to remove (or “unlearn”) the impact of a specific collection of training examples, called the “forget set” $S_F$, from a machine learning model $\theta$ trained on a dataset $S$, in a computationally efficient fashion.$^2$

Unlearning methods try to find a shortcut from the original model $\theta(S)$  to the retrained model  $\theta(S \setminus S_F )$

The typical notion of success for unlearning is that the unlearned model is indistinguishable from a model retrained from scratch without the forget set (which we refer to as an “oracle”). Since comparing overparameterized models directly is extremely difficult, in practice we aim for indistinguishability of model outputs. Concretely, we retrain models on the dataset excluding the forget set (called the “retain set,” $S_R = S \setminus S_F$) and measure the distance between their predictions and those of the unlearned models. The choice of distance metric is both very important and quite subtle; we propose a new metric called KL divergence Of Margins (KLoM).$^3$
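To make this kind of output-space comparison concrete, here is a minimal sketch under some simplifying assumptions: we summarize each model's prediction on an example by a scalar margin (e.g., the correct-class logit minus the largest incorrect-class logit), collect margins from an ensemble of oracle models retrained without $S_F$ and an ensemble of unlearned models, and approximate each per-example margin distribution as a Gaussian so the KL divergence has a closed form. The exact definition of KLoM may differ in these details, and all function and array names below are illustrative.

```python
# Sketch: compare per-example output distributions of unlearned models against
# oracle models retrained without the forget set, via a Gaussian approximation
# of each per-example margin distribution. Illustrative, not the exact metric.
import numpy as np

def gaussian_kl(mu_p, var_p, mu_q, var_q, eps=1e-8):
    """KL( N(mu_p, var_p) || N(mu_q, var_q) ), computed elementwise."""
    var_p, var_q = var_p + eps, var_q + eps
    return 0.5 * (np.log(var_q / var_p) + (var_p + (mu_p - mu_q) ** 2) / var_q - 1.0)

def per_example_divergence(margins_oracle, margins_unlearned):
    """
    margins_oracle:    shape (n_oracle_models, n_examples), margins of models
                       retrained from scratch on the retain set only.
    margins_unlearned: shape (n_unlearned_models, n_examples), margins of models
                       produced by the unlearning algorithm.
    Returns a per-example divergence; lower means the unlearned models' outputs
    are harder to distinguish from the oracles'.
    """
    mu_o, var_o = margins_oracle.mean(axis=0), margins_oracle.var(axis=0)
    mu_u, var_u = margins_unlearned.mean(axis=0), margins_unlearned.var(axis=0)
    return gaussian_kl(mu_u, var_u, mu_o, var_o)
```

A per-example score like this can then be aggregated separately over forget, retain, and held-out points, which helps surface both incomplete forgetting and collateral damage to retained knowledge.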

A high-level description of how unlearning evaluation generally looks (using a similar framing to the Neurips 2023 Unlearning challenge).

The “Missing Targets” Problem

Failure Modes for Gradient-Based Unlearning

The majority of existing unlearning algorithms start with $\theta(S)$ and unlearn by fine-tuning the model. Specifically, they use some combination of:

  1. Gradient Ascent (GA) on the forget set $S_F$, in order to undo the impact of the points we want to forget.
  2. Gradient Descent (GD) on the retain set $S_R$, in order to reinforce the impact of the points that remain.
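To make this recipe concrete, below is a minimal PyTorch-style sketch of such a fine-tuning loop. The function and its arguments are illustrative rather than the implementation of any specific published method; methods like SCRUB build additional terms and schedules on top of this basic template.

```python
# Minimal sketch of gradient-based unlearning by fine-tuning: gradient ascent
# on the forget set combined with gradient descent on the retain set.
# All names are illustrative; loaders are assumed to yield (inputs, labels).
from itertools import cycle

import torch
import torch.nn.functional as F

def gradient_based_unlearn(model, forget_loader, retain_loader,
                           steps=100, lr=1e-4, forget_weight=1.0):
    opt = torch.optim.SGD(model.parameters(), lr=lr)
    forget_batches, retain_batches = cycle(forget_loader), cycle(retain_loader)
    model.train()
    for _ in range(steps):
        x_f, y_f = next(forget_batches)   # batch we want the model to "forget"
        x_r, y_r = next(retain_batches)   # batch we want to keep performing on
        # Descend on the retain loss, ascend on the forget loss (negated term).
        loss = (F.cross_entropy(model(x_r), y_r)
                - forget_weight * F.cross_entropy(model(x_f), y_f))
        opt.zero_grad()
        loss.backward()
        opt.step()
    return model
```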

However, we find (as does other recent work) that this general approach comes with a significant set of drawbacks, which we collectively refer to as the missing targets problem.

  1. First, the assumption underlying both gradient-based updates (that forget set points will increase in loss after unlearning while retain set points will not) does not always hold in practice. For example, if there are similar points in the forget and retain sets, excluding the forget set may increase the model's loss on the similar points in the retain set; conversely, the model may perform just as well on the forget set after unlearning if it can generalize sufficiently well from similar points in the retain set.

    For example, the text from Harry Potter “He made several important telephone calls and shouted a bit more” is very similar to text one might find in “The Hitchhiker’s Guide to the Galaxy” or a legal case; thus, even after removing Harry Potter from our dataset, a perfect unlearner may still be able to recite passages like this, having learned from other texts.

  2. Second, even for a forget set point whose loss does increase, a perfect unlearner would not increase loss arbitrarily, but instead only until it reaches the expected loss under a perfectly retrained model — its “target value.”

Since we lack access to these target values, it is challenging to know when a given forget set point has been “unlearned,” and thus many existing heuristic schemes overshoot or undershoot the target loss for a given data point. Compounding this problem, different points may reach their target values at different times over the run of the unlearning algorithm, meaning that no single stopping time achieves good unlearning performance across all points.
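To illustrate what these missing targets would look like if we did have them, the sketch below estimates a per-example target loss from an ensemble of oracle models retrained without $S_F$, and measures how far an unlearned model over- or undershoots it. Computing this requires exactly the retraining that unlearning is meant to avoid, which is why it is only usable for evaluation; the function and array names are illustrative.

```python
# Illustrative only: the per-example "target" is the expected loss under models
# retrained from scratch without the forget set. Unlearning algorithms do not
# have access to these values at run time.
import numpy as np

def target_gap(unlearned_losses, oracle_losses):
    """
    unlearned_losses: shape (n_examples,), per-example loss of the unlearned model.
    oracle_losses:    shape (n_oracles, n_examples), per-example losses of models
                      retrained from scratch on the retain set only.
    Returns the per-example gap: positive values mean the unlearning algorithm
    overshot the target loss; negative values mean it undershot (stopped early).
    """
    targets = oracle_losses.mean(axis=0)  # expected loss under retraining
    return unlearned_losses - targets
```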

The figure below illustrates this phenomenon for a popular unlearning algorithm called SCRUB.$^4$ Over iterations of the algorithm, different points are unlearned (and then subsequently “overshot”) at different points in time (Hayes et al., 2024).

Each line represents the unlearning quality for an individual data point, as a function of the number of iterations of SCRUB. Data points whose unlearning quality worsens over time are highlighted in red. This is a particular issue for forget points.
