Antidistillation Sampling
Summary
Zico Kolter introduces "Antidistillation Sampling," a computationally efficient method to generate text from closed-source models while rendering the outputs ineffective as training data for other models.
Session Transcript
I'm going to talk today about some work we did recently on antidistillation sampling. To motivate this: it's actually a fairly well-contained problem, and one that I think is pretty interesting. The basic question is the following: can we release a closed-source, query-only, black-box model, where you can generate samples from this model, but training on those samples does not improve a student model?
Now, there are obvious competitive implications for doing this. I think companies want to release closed-source models without open-source variants that are just as good appearing soon afterward. There are also some real safety implications. I think there's value in at least a period where closed-source models at the frontier stay ahead of open-source models, so we can understand their capabilities. I would rather not live in a world where every time a closed-source model is released, we can easily clone it with an open-source variant. That's my personal belief. Beyond that, I just think it's interesting to ask: can you create samples that somehow aren't very useful to train on? That's the goal of antidistillation sampling.
Here is the setting. We have a teacher model, and we can draw samples from it. The teacher model specifies a distribution over next tokens given previous tokens, and we can sample according to it. All you can get as a third party is samples from this model. We also have a student model, but we don't actually know which student someone downstream might be using, so we're going to think of it as a proxy model. We're going to have some proxy student model, and we're going to think of this model as being trained on the samples from the teacher. So whenever we take a sample 𝑥_{t+1} from the teacher model, we then take a small gradient step on that token in the proxy model. That's how we update the proxy model parameters. Finally, we have some downstream loss, ℒ, which is a function of our proxy model. This measures how good the proxy model is on some, in this case, differentiable downstream task, like the likelihood it assigns to a large corpus of data that we think corresponds to the task at hand.
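To make the setup concrete, here is a minimal toy sketch in Python. Everything in it (the `ToyLM` class, `downstream_loss`, the toy corpus, the learning rate) is an illustrative stand-in assumed for exposition, not the paper's implementation:

```python
import torch
import torch.nn.functional as F

VOCAB = 32  # toy vocabulary size, purely for illustration

class ToyLM(torch.nn.Module):
    """Stand-in language model: maps the last context token to next-token logits."""
    def __init__(self):
        super().__init__()
        self.embed = torch.nn.Embedding(VOCAB, 16)
        self.head = torch.nn.Linear(16, VOCAB)

    def forward(self, context):                      # context: 1-D tensor of token ids
        return self.head(self.embed(context[-1]))    # logits over the next token

teacher, proxy = ToyLM(), ToyLM()

# Toy stand-in for the large downstream corpus the talk mentions:
# pairs of (context token, target token).
corpus = torch.randint(0, VOCAB, (64, 2))

def downstream_loss(model):
    """Differentiable downstream loss L(theta_P): how well the proxy fits the corpus."""
    logits = model.head(model.embed(corpus[:, 0]))
    return F.cross_entropy(logits, corpus[:, 1])

def distillation_step(context, lr=1e-3):
    """Sample x_{t+1} from the teacher, then take one small gradient step on that
    token in the proxy, mimicking a student distilling from the teacher's samples."""
    with torch.no_grad():
        next_token = torch.multinomial(F.softmax(teacher(context), dim=-1), 1)

    proxy.zero_grad()
    F.cross_entropy(proxy(context).unsqueeze(0), next_token).backward()
    with torch.no_grad():
        for p in proxy.parameters():
            p -= lr * p.grad        # updated proxy parameters ("theta_P plus")

    return next_token, downstream_loss(proxy).item()
```

For example, `distillation_step(torch.tensor([3, 7, 1]))` simulates one distillation step and returns the sampled token together with the updated proxy's downstream loss.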
Now, the goal of antidistillation sampling is really quite simple. We want to sample from the teacher model, but do it in a way that mixes together two different objectives. On the one hand, you want to sample tokens that have a high likelihood under the teacher model; that's the first term. That's just normal temperature-based sampling, as you would have when sampling from any LLM. The second term is the interesting one. It trades off tokens that are likely under the teacher model against tokens that cause the loss of the proxy model to increase when you train on that one token. Going back to the notation, 𝜃⁺_P denotes the updated proxy model parameters. We want the loss of the updated model to be larger than the loss of the original model. Normally you try to minimize the loss; here we want the loss to get larger, so we want this difference to be positive. Therefore, we're going to sample tokens that tend to favor increases in loss for the proxy model.
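In symbols, the sampling rule just described can be sketched roughly as follows (the notation is mine and may not match the paper's exact formulation):

```latex
x_{t+1} \;\sim\; \operatorname{softmax}_{x}\!\left(
    \frac{1}{\tau}\,\log p_{T}\!\left(x \mid x_{1:t}\right)
    \;+\; \lambda\,\Bigl[\mathcal{L}\bigl(\theta_{P}^{+}(x)\bigr) - \mathcal{L}\bigl(\theta_{P}\bigr)\Bigr]
\right)
```

Here $p_T$ is the teacher's next-token distribution, $\tau$ the sampling temperature, $\theta_P^{+}(x)$ the proxy parameters after a small gradient step on candidate token $x$, and $\lambda$ a weight (again, my notation) controlling the tradeoff between teacher likelihood and making the sample unhelpful for distillation.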
I want to make the point here that doing this naively is not really practical. This loss ℒ is pretty expensive to evaluate; think of it as the loss of the student model on some very large corpus. That loss depends on the token you sample, so in theory you would have to evaluate it, over that large downstream corpus, for every possible next token, every time you want to generate a single token from your model. I don't have time to go through it, but this is the fun math part of the paper. It turns out that using some clever tricks, coming basically from just differentiation (they're not that clever, but they were fun to work out), you can evaluate that quantity efficiently for all possible next tokens, that is, for all possible sampling points. You can find the next sample that maximizes that tradeoff using just one forward pass in the teacher model, as you would normally need, plus two forward passes in the proxy model.
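The paper has the actual derivation; purely to give a feel for how a one-teacher-pass, two-proxy-pass recipe could be realized, here is a hedged sketch continuing the toy Python setup above. The specific first-order and finite-difference approximations, and the hyperparameters `tau`, `lam`, and `delta`, are my own illustrative reading of the description in the talk, not the paper's algorithm verbatim:

```python
def antidistillation_scores(context, tau=1.0, lam=1.0, delta=1e-3):
    """Hedged sketch of the efficient scoring step, reusing the toy setup above.
    One teacher forward pass plus two proxy forward passes yield an approximate
    per-token change in the downstream loss for every candidate next token at once."""
    # (1) One forward pass in the teacher, exactly as in ordinary sampling.
    with torch.no_grad():
        teacher_logprobs = F.log_softmax(teacher(context), dim=-1)

    # (2) Gradient g of the downstream loss w.r.t. the proxy parameters.
    #     It does not depend on the candidate token, so it is computed once.
    proxy.zero_grad()
    downstream_loss(proxy).backward()
    g = [p.grad.detach().clone() for p in proxy.parameters()]

    # (3) Two proxy forward passes: at theta_P and at theta_P perturbed along g.
    #     The difference of log-probs is a finite-difference estimate of the
    #     directional derivative of log p_P(x | context) along g, for all tokens x
    #     at once; by a first-order argument, training on token x changes the
    #     downstream loss roughly in proportion to this quantity.
    with torch.no_grad():
        base_logprobs = F.log_softmax(proxy(context), dim=-1)
        for p, gi in zip(proxy.parameters(), g):
            p += delta * gi
        pert_logprobs = F.log_softmax(proxy(context), dim=-1)
        for p, gi in zip(proxy.parameters(), g):
            p -= delta * gi                      # restore the proxy

        est_loss_increase = (pert_logprobs - base_logprobs) / delta

    # Mix teacher likelihood with the estimated loss increase and sample.
    scores = teacher_logprobs / tau + lam * est_loss_increase
    return torch.multinomial(F.softmax(scores, dim=-1), 1)
```

The point the sketch tries to capture is that the expensive quantity, how the downstream loss would change for each candidate token, collapses to a difference of two proxy logit vectors.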
All right, so I won't go over the details. It involves directional derivatives and finite-difference approximations applied in two different ways, which is what lets you do this efficiently. OK, that's the approach; that's actually all I'll say here, and I'll defer the actual details to the paper, but it's relatively cheap to compute. The question is, does it work? The answer is: kind of. I need to make the big point here that this is not a practical approach yet by any means, but I think it is interesting how well it can do already. The blue line here shows the tradeoff, on GSM8K (simple math questions), between teacher accuracy when you sample according to different mechanisms and student accuracy when you train the student on the generated samples.
What we see here is that if you use temperature scaling to generate noisier samples, those samples are still helpful for the student model to train on. Even increasing the noise to the point where the teacher gets only 50% accuracy still produces pretty useful samples. When you sample with this alternative approach, you're still paying a relatively large hit to teacher accuracy, which is not something you'd want to do in practice, but I think it's interesting that for a relatively small change in teacher accuracy, you can make the student perform no better than the original student baseline. Essentially, at around 70% teacher accuracy (a drop of a quarter or a third or so), you get samples that do not help the student at all. That's interesting because the teacher is still doing quite a bit better; you're still solving the problems as far as the teacher is concerned, but the generated samples are not very useful.
It's also instructive to look at what these samples actually look like; I won't go over it too much here. These are samples generated at the 80% accuracy level for the teacher, right around here on the curve. This decreases the value of the samples: student accuracy drops from about 62% to about 45%. It's interesting because the traces are mostly reasonable, but you get funny things like XML-RPC code showing up in money word problems about middle school students. What you get are mostly good-looking traces, with a few adversarial examples thrown in naturally by this process, and those degrade student performance when you train on them.
I'm sure you could circumvent this if you wanted to, by writing filters after the fact or by resampling after the fact. But if you just train on the samples naively, they don't help; doing so degrades student performance substantially, sometimes in ways that are a little obvious, but sometimes in ways that are really subtle, that you can't quite see. It just somehow changes the samples so that they're not very useful. To conclude, I don't think this is something you want to do yet: it costs two extra forward passes, even if only in a smaller proxy model, for every sample you generate from your model. But I think it's conceptually interesting that we can start to generate samples that target some secondary objective about what happens when a student model is trained on them.
The interesting thing is that this can be computed efficiently whenever your downstream objective involves the difference of losses between the updated and the non-updated proxy model. More information is available here. I think it's more an interesting project at this point, an initial exploration, rather than something you'd actually want to use in practice. Nonetheless, I find it cool that it's possible to generate samples with this property, samples that have a very different tradeoff in terms of how useful they are for training. Thanks very much. [Applause]