Commit d18a79e

llama_batch_ext_init with ctx

1 parent 1434c2c

File tree: 41 files changed, +124 -113 lines. (Large commits have some content hidden by default; only a subset of the changed files appears below.)

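The pattern repeated across every file below: llama_batch_ext_init drops its explicit capacity arguments (max token count and max sequence count) and instead takes the llama_context, presumably deriving both limits from the context's own parameters. A minimal before/after sketch; that the sizing comes from llama_n_batch(ctx) and llama_n_seq_max(ctx) internally is an assumption, not something this diff shows:

    // before: the caller picked the capacity explicitly
    //     llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
    // after: the capacity is derived from the context (assumed to use
    // llama_n_batch(ctx) / llama_n_seq_max(ctx) internally)
    llama_batch_ext * batch = llama_batch_ext_init(ctx);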

common/common.cpp (2 additions, 2 deletions)

@@ -1016,7 +1016,7 @@ struct common_init_result common_init_from_params(common_params & params) {
         }
 
         if (llama_model_has_encoder(model)) {
-            auto batch = llama_batch_ext_ptr::init_from_text(tmp.data(), tmp.size(), 0, 0, true);
+            auto batch = llama_batch_ext_ptr::init_from_text(lctx, tmp.data(), tmp.size(), 0, 0, true);
             llama_encode_ext(lctx, batch.get());
             llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
             if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
@@ -1026,7 +1026,7 @@ struct common_init_result common_init_from_params(common_params & params) {
             tmp.push_back(decoder_start_token_id);
         }
         if (llama_model_has_decoder(model)) {
-            auto batch = llama_batch_ext_ptr::init_from_text(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0, true);
+            auto batch = llama_batch_ext_ptr::init_from_text(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0, true);
             llama_decode_ext(lctx, batch.get());
         }
         llama_kv_self_clear(lctx);
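In both hunks, llama_batch_ext_ptr::init_from_text gains the context as a new leading argument while the rest of the parameter list (token pointer, token count, starting position, sequence id, output-last flag) is unchanged. A hedged sketch of the updated call shape; the names of the trailing parameters are guesses inferred from these call sites:

    // sketch: same call shape as the encoder hunk above (parameter names
    // pos_0, seq_id, output_last are assumptions, not confirmed by the diff)
    auto batch = llama_batch_ext_ptr::init_from_text(
            lctx, tmp.data(), tmp.size(),
            /*pos_0=*/0, /*seq_id=*/0, /*output_last=*/true);
    llama_encode_ext(lctx, batch.get());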

common/speculative.cpp (1 addition, 1 deletion)

@@ -23,7 +23,7 @@ struct common_speculative * common_speculative_init(
     auto * result = new common_speculative {
         /* .ctx    = */ ctx_dft,
         /* .smpl   = */ nullptr,
-        /* .batch  = */ llama_batch_ext_ptr(llama_batch_ext_init(llama_n_batch(ctx_dft), 1)),
+        /* .batch  = */ llama_batch_ext_ptr(ctx_dft),
         /* .prompt = */ {},
     };
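The single-argument llama_batch_ext_ptr(ctx_dft) form implies the RAII wrapper gained a constructor that performs the llama_batch_ext_init call itself. A plausible shape for that constructor, purely an assumption since the wrapper's definition (presumably in llama-cpp.h) is not part of this diff:

    #include <llama.h>
    #include <memory>

    // hypothetical sketch of the wrapper the new call site implies
    struct llama_batch_ext_deleter {
        // assuming a matching llama_batch_ext_free exists in the C API
        void operator()(llama_batch_ext * b) const { llama_batch_ext_free(b); }
    };

    struct llama_batch_ext_ptr : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter> {
        llama_batch_ext_ptr() = default;
        explicit llama_batch_ext_ptr(llama_context * ctx)
            : std::unique_ptr<llama_batch_ext, llama_batch_ext_deleter>(llama_batch_ext_init(ctx)) {}
    };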

examples/batched-bench/batched-bench.cpp (1 addition, 1 deletion)

@@ -59,7 +59,7 @@ int main(int argc, char ** argv) {
 
     const int32_t n_kv_max = llama_n_ctx(ctx);
 
-    llama_batch_ext * batch = llama_batch_ext_init(n_kv_max, 1);
+    llama_batch_ext * batch = llama_batch_ext_init(ctx);
 
     // decode in batches of ctx_params.n_batch tokens
     auto decode_helper = [](llama_context * ctx, llama_batch_ext * batch, int32_t n_batch) {
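A subtle consequence here: the old code sized the batch to n_kv_max (the full KV cache), whereas the context-based init presumably caps capacity at the context's batch size. Since the decode_helper below already splits work into ctx_params.n_batch-sized chunks, the visible behavior should be unchanged and only the peak batch allocation differs; that reading of the new sizing is an assumption, not something the diff states.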

examples/batched/batched.cpp (1 addition, 1 deletion)

@@ -102,7 +102,7 @@ int main(int argc, char ** argv) {
 
     // create a llama_batch
     // we use this object to submit token data for decoding
-    llama_batch_ext * batch = llama_batch_ext_init(std::max(tokens_list.size(), (size_t) n_parallel), n_parallel);
+    llama_batch_ext * batch = llama_batch_ext_init(ctx);
 
     std::vector<llama_seq_id> seq_ids(n_parallel, 0);
     for (int32_t i = 0; i < n_parallel; ++i) {
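The old call sized the batch as std::max(tokens_list.size(), (size_t) n_parallel); with capacity now decided by the context, the burden moves to context creation. A sketch of what the setup earlier in this example would need to guarantee (llama_context_params and its n_batch/n_seq_max fields exist in llama.h; wiring them up this way here is an assumption):

    // the context must be created large enough for the prompt plus the
    // parallel sequences, since the batch capacity now comes from it (sketch)
    llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_batch   = std::max(tokens_list.size(), (size_t) n_parallel);
    ctx_params.n_seq_max = n_parallel;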

examples/cvector-generator/cvector-generator.cpp (1 addition, 1 deletion)

@@ -343,7 +343,7 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) {
 
 static bool get_hidden_layers(llama_context * ctx, std::vector<llama_token> & tokens) {
     llama_kv_self_clear(ctx);
-    auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), 0, 0, true);
+    auto batch = llama_batch_ext_ptr::init_from_text(ctx, tokens.data(), tokens.size(), 0, 0, true);
     if (llama_decode_ext(ctx, batch.get())) {
         fprintf(stderr, "%s : failed to eval\n", __func__);
         return false;

examples/embedding/embedding.cpp (1 addition, 1 deletion)

@@ -167,7 +167,7 @@ int main(int argc, char ** argv) {
 
     // initialize batch
     const int n_prompts = prompts.size();
-    llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
+    llama_batch_ext * batch = llama_batch_ext_init(ctx);
 
     // count number of embeddings
     int n_embd_count = 0;

examples/eval-callback/eval-callback.cpp (1 addition, 1 deletion)

@@ -134,7 +134,7 @@ static bool run(llama_context * ctx, const common_params & params) {
 
     std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
 
-    auto batch = llama_batch_ext_ptr::init_from_text(tokens.data(), tokens.size(), 0, 0, true);
+    auto batch = llama_batch_ext_ptr::init_from_text(ctx, tokens.data(), tokens.size(), 0, 0, true);
     if (llama_decode_ext(ctx, batch.get())) {
         LOG_ERR("%s : failed to eval\n", __func__);
         return false;

examples/gritlm/gritlm.cpp (2 additions, 2 deletions)

@@ -14,7 +14,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
     const llama_model * model = llama_get_model(ctx);
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
-    llama_batch_ext_ptr batch(llama_batch_ext_init(llama_n_batch(ctx), 1));
+    llama_batch_ext_ptr batch(ctx);
 
     for (uint64_t i = 0; i < sentences.size(); i++) {
         batch.clear();
@@ -105,7 +105,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
     llama_set_embeddings(ctx, false);
     llama_set_causal_attn(ctx, true);
 
-    llama_batch_ext_ptr batch(llama_batch_ext_init(llama_n_batch(ctx), 1));
+    llama_batch_ext_ptr batch(ctx);
 
     std::vector<llama_token> inputs = common_tokenize(vocab, prompt, false, true);
     int32_t i_current_token = 0;
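Both hunks swap the wrapped raw init for the context-based constructor; the surrounding context lines show the idiom the wrapper supports, one batch reused across all sentences via clear(). A condensed sketch of that loop:

    // one wrapper-owned batch, reused per sentence (fill/decode elided)
    llama_batch_ext_ptr batch(ctx);
    for (uint64_t i = 0; i < sentences.size(); i++) {
        batch.clear();  // reset before refilling with the next sentence
        // ... add the sentence's tokens, then decode and read embeddings ...
    }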

examples/imatrix/imatrix.cpp (1 addition, 1 deletion)

@@ -497,7 +497,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
         // clear the KV cache
         llama_kv_self_clear(ctx);
 
-        llama_batch_ext * batch = llama_batch_ext_init(n_batch, 1);
+        llama_batch_ext * batch = llama_batch_ext_init(ctx);
 
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;

examples/infill/infill.cpp (1 addition, 1 deletion)

@@ -353,7 +353,7 @@ int main(int argc, char ** argv) {
 
                 LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str());
 
-                auto batch = llama_batch_ext_ptr::init_from_text(&embd[i], n_eval, n_past, 0, true);
+                auto batch = llama_batch_ext_ptr::init_from_text(ctx, &embd[i], n_eval, n_past, 0, true);
                 if (llama_decode_ext(ctx, batch.get())) {
                     LOG_ERR("%s : failed to eval\n", __func__);
                     return 1;
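Unlike most call sites in this commit, infill passes n_past rather than 0 as the position argument, which is what keeps successive chunks contiguous in the KV cache during chunked evaluation. A hedged sketch of that loop shape, with the roles of n_eval and n_past inferred from the hunk:

    // chunked evaluation: each batch starts at position n_past, so successive
    // chunks line up in the KV cache (sketch; loop bounds inferred from the hunk)
    for (int i = 0; i < (int) embd.size(); i += n_eval) {
        n_eval = (int) embd.size() - i;
        if (n_eval > params.n_batch) {
            n_eval = params.n_batch;  // clamp each chunk to the batch size
        }
        auto batch = llama_batch_ext_ptr::init_from_text(ctx, &embd[i], n_eval, n_past, 0, true);
        if (llama_decode_ext(ctx, batch.get())) {
            LOG_ERR("%s : failed to eval\n", __func__);
            return 1;
        }
        n_past += n_eval;
    }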
