-
Notifications
You must be signed in to change notification settings - Fork 12.7k
sampling : add XTC sampler #9742
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 47 commits
89640b0
9455194
db54ac5
41e1665
d9c9203
f2a2a61
4f8e55b
6d94ba2
49cd211
899e073
74f657c
59e8e63
63e60de
094caea
39940e5
4c44e3d
dbe9ef7
98b204c
8110f78
81a0c26
09bc6d5
c19fb26
6feb6b3
d0b1053
ed535bb
37e02e3
ba29d31
2107882
f7a383f
72db625
882a603
3968369
acada1a
dfe587a
9c43a01
68557eb
ea85a51
cca842f
ea62e65
44bbd63
a3e6522
dfef2c4
436a991
3613a6d
17ad143
2be814a
28d2cff
3496f58
050eb7a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1059,6 +1059,101 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa | |
}; | ||
} | ||
|
||
// xtc | ||
|
||
// State of the XTC sampler.
// Members are listed in the order the positional aggregate initialisation in
// llama_sampler_init_xtc expects - do not reorder.
struct llama_sampler_xtc {
    const float    probability; // apply() is a no-op when <= 0; otherwise compared against a uniform [0, 1) draw
    const float    threshold;   // candidates with p >= threshold are considered for removal; > 0.5 disables apply()
    const size_t   min_keep;    // apply() removes nothing if fewer than this many candidates would remain

    const uint32_t seed;        // seed as passed to llama_sampler_init_xtc
    uint32_t       seed_cur;    // value of get_rng_seed(seed) that `rng` was last seeded with

    std::mt19937   rng;         // generator for the per-call draw in apply()
};
|
||
// Returns the identifier of this sampler; the sampler argument is not inspected.
static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
    static const char * const sampler_name = "xtc";
    return sampler_name;
}
|
||
static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { | ||
auto * ctx = (llama_sampler_xtc *) smpl->ctx; | ||
|
||
if (ctx->probability <= 0.0f | ||
|| ctx->threshold > 0.5f | ||
|| cur_p->size < 2) { | ||
return; | ||
} | ||
|
||
std::uniform_real_distribution<float> distribution(0.0f, 1.0f); | ||
float chance = distribution(ctx->rng); | ||
if (chance > ctx->probability) return; | ||
slaren marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
// in case it's not sorted/recalculated yet | ||
llama_sampler_softmax_impl(cur_p); | ||
|
||
int pos_last = 0; | ||
|
||
for (size_t i = 0; i < cur_p->size; ++i) { | ||
if (cur_p->data[i].p >= ctx->threshold) { | ||
pos_last = i; | ||
} else break; | ||
} | ||
|
||
if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) { | ||
cur_p->data += pos_last; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This may potentially break 3rd party code that expects this pointer to be unchanged (eg. to free it after sampling). I don't think this is necessarily a problem, but we should make it clear that this pointer may be changed by the samplers, and applications should not rely on it being unchanged. |
||
cur_p->size = cur_p->size - pos_last; | ||
MaggotHATE marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
} | ||
|
||
static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) { | ||
const auto * ctx = (const llama_sampler_xtc *) smpl->ctx; | ||
auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed); | ||
|
||
// copy the state | ||
{ | ||
auto * result_ctx = (llama_sampler_xtc *) result->ctx; | ||
|
||
result_ctx->rng = ctx->rng; | ||
} | ||
|
||
return result; | ||
} | ||
|
||
static void llama_sampler_xtc_free(struct llama_sampler * smpl) { | ||
delete (llama_sampler_xtc *) smpl->ctx; | ||
} | ||
|
||
static void llama_sampler_xtc_reset(struct llama_sampler * smpl) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the purpose of this function? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. AFAIK this is necessary to properly reset seed and maintain repeatability, as recommended by @slaren earlier. |
||
auto * ctx = (llama_sampler_xtc *) smpl->ctx; | ||
ctx->seed_cur = get_rng_seed(ctx->seed); | ||
ctx->rng.seed(ctx->seed_cur); | ||
} | ||
|
||
// Callback table for the XTC sampler; llama_sampler_init_xtc stores its address
// in the `iface` slot of every sampler it creates.
static struct llama_sampler_i llama_sampler_xtc_i = {
    /* .name   = */ llama_sampler_xtc_name,   // returns "xtc"
    /* .accept = */ nullptr,                  // no accept callback is registered
    /* .apply  = */ llama_sample_xtc_apply,   // trims the candidate list
    /* .reset  = */ llama_sampler_xtc_reset,  // re-seeds the generator
    /* .clone  = */ llama_sampler_xtc_clone,  // duplicates parameters and RNG state
    /* .free   = */ llama_sampler_xtc_free,   // deletes the context
};
|
||
struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) { | ||
auto seed_cur = get_rng_seed(seed); | ||
return new llama_sampler { | ||
/* .iface = */ &llama_sampler_xtc_i, | ||
/* .ctx = */ new llama_sampler_xtc { | ||
/* .probability = */ p, | ||
/* .threshold = */ t, | ||
/* .min_keep = */ min_keep, | ||
/* .seed = */ seed, | ||
/* .seed_cur = */ seed_cur, | ||
/* .rng = */ std::mt19937(seed_cur), | ||
}, | ||
}; | ||
} | ||
|
||
// mirostat | ||
|
||
struct llama_sampler_mirostat { | ||
|
Uh oh!
There was an error while loading. Please reload this page.