@@ -139,6 +139,7 @@ struct slot_params {
    json input_prefix;
    json input_suffix;
+   json extra_context;
};

struct server_slot {
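Note (not part of the patch): the new `extra_context` field carries additional repository chunks sent with an infill request. The sketch below uses the same nlohmann `json` type the server already relies on to show the chunk shape that the validation added later in this diff expects; the field values and the helper name are made up for illustration.

```cpp
#include <cstdio>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Sketch only: an infill-style request body carrying the new "extra_context" field.
// The chunk shape { "text", "filename" } matches the validation added later in this
// diff; the concrete values and the other fields are illustrative.
static json make_infill_request() {
    return json{
        {"input_prefix", "int add(int a, int b) {\n    return "},
        {"input_suffix", ";\n}\n"},
        {"extra_context", json::array({
            json{{"filename", "utils.h"},   {"text", "int add(int a, int b);\n"}},
            json{{"filename", "notes.txt"}, {"text", "All math helpers live in utils.\n"}}
        })}
    };
}

int main() {
    printf("%s\n", make_infill_request().dump(2).c_str());
    return 0;
}
```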
@@ -170,6 +171,7 @@ struct server_slot {
    // when a task is submitted, we first tokenize the prompt and store it here
    std::vector<llama_token> prompt_tokens;
+   std::vector<llama_token> extra_tokens;

    std::string generated_text;
    std::vector<llama_token> cache_tokens;
@@ -800,7 +802,7 @@ struct server_context {
                int slot_prompt_len = slot_prompt.size();

                // length of the Longest Common Prefix between the current slot's prompt and the input prompt
-               int lcp_len = common_part(slot_prompt, prompt);
+               int lcp_len = longest_common_prefix(slot_prompt, prompt);

                // fraction of the common substring length compared to the current slot's prompt length
                similarity = static_cast<float>(lcp_len) / slot_prompt_len;
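For reference, the rename above points at a small utility: `longest_common_prefix` (formerly `common_part`) returns the length of the longest common prefix of two sequences. It is used here on prompt strings and further down on token vectors. The real helper is defined in the server sources; the templated version below is only a sketch of that contract.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Illustration of the behavior behind the rename: the length of the longest common
// prefix of two sequences. Works for strings and token vectors alike.
template <typename T>
static size_t longest_common_prefix_sketch(const T & a, const T & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        i++;
    }
    return i;
}

int main() {
    const std::string a = "int add(int a, int b)";
    const std::string b = "int add(float a, float b)";
    return (int) longest_common_prefix_sketch(a, b); // 8, i.e. "int add("
}
```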
@@ -906,8 +908,26 @@ struct server_context {
        }

        // infill
-       slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
-       slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
+       slot.params.input_prefix  = json_value(data, "input_prefix",  default_params.input_prefix);
+       slot.params.input_suffix  = json_value(data, "input_suffix",  default_params.input_suffix);
+       slot.params.extra_context = json_value(data, "extra_context", default_params.extra_context);
+
+       SLT_DBG(slot, "extra_context chunks: %d\n", (int) slot.params.extra_context.size());
+       for (const auto & chunk : slot.params.extra_context) {
+           // { "text": string, "filename": string }
+           if (!chunk.contains("text") || !chunk["text"].is_string()) {
+               send_error(task, "extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST);
+               return false;
+           }
+
+           // filename is optional
+           if (chunk.contains("filename") && !chunk["filename"].is_string()) {
+               send_error(task, "extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST);
+               return false;
+           }
+
+           SLT_DBG(slot, "extra_context chunk in file '%s':\n%s\n", chunk.value("filename", "").c_str(), chunk.value("text", "").c_str());
+       }

        // get prompt
        if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
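The `json_value(data, key, default)` calls above use the server's existing helper for optional request fields. A rough sketch of the behavior relied on here is shown below (assumption: the real implementation lives in the server's utility header and additionally guards against type mismatches). In particular, a request without `extra_context` falls back to the default, so the validation loop above simply sees no chunks.

```cpp
#include <nlohmann/json.hpp>
#include <string>

using json = nlohmann::json;

// Rough sketch (assumption) of what the json_value calls above rely on: return the
// value stored under `key` when present and non-null, otherwise the supplied default.
template <typename T>
static T json_value_sketch(const json & body, const std::string & key, const T & default_value) {
    if (body.contains(key) && !body.at(key).is_null()) {
        return body.at(key).get<T>();
    }
    return default_value;
}

int main() {
    const json data = {{"input_prefix", "int x = "}};

    const std::string prefix = json_value_sketch<std::string>(data, "input_prefix", "");  // "int x = "
    const json        extra  = json_value_sketch<json>(data, "extra_context", json());    // missing -> default

    return (int) extra.size(); // 0: no chunks to validate or tokenize
}
```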
@@ -1934,13 +1954,66 @@ struct server_context {
                } break;
            case SERVER_TASK_CMPL_TYPE_INFILL:
                {
+                   // use FIM repo-level pattern:
+                   // ref: https://arxiv.org/pdf/2409.12186
+                   //
+                   // [FIM_REP]myproject
+                   // [FIM_SEP]filename0
+                   // extra chunk 0
+                   // [FIM_SEP]filename1
+                   // extra chunk 1
+                   // ...
+                   // [FIM_SEP]filename
+                   // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]
+                   //
                    auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
                    auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);

-                   // for now pick context to fit in a single batch (ratio prefix:suffix = 3:1, TODO: configurable?)
-                   const int n_suffix_take = std::min<int>(suffix_tokens.size(), n_batch/4);
+                   slot.extra_tokens.clear();
+                   if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) {
+                       static const auto k_fim_repo = tokenize("myproject\n", false, false);
+
+                       slot.extra_tokens.push_back(llama_token_fim_rep(model));
+                       slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
+                   }
+
+                   for (const auto & chunk : slot.params.extra_context) {
+                       // { "text": string, "filename": string }
+                       const std::string text     = chunk.value("text", "");
+                       const std::string filename = chunk.value("filename", "tmp");
+
+                       if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
+                           const auto k_fim_file = tokenize(filename + "\n", false, false);
+
+                           slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
+                           slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
+                       } else {
+                           // chunk separator in binary form to avoid confusing the AI
+                           static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
+                           static const auto k_chunk_prefix_tokens = tokenize(k_chunk_prefix_str, false, false);
+
+                           slot.extra_tokens.insert(slot.extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
+                       }
+
+                       const auto chunk_tokens = tokenize(text, false, false);
+                       slot.extra_tokens.insert(slot.extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
+                   }
+
+                   if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) {
+                       // TODO: current filename
+                       static const auto k_fim_file = tokenize("filename\n", false, false);
+
+                       slot.extra_tokens.insert(slot.extra_tokens.end(), llama_token_fim_sep(model));
+                       slot.extra_tokens.insert(slot.extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
+                   }
+
+                   // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?)
+                   const int n_suffix_take = std::min<int>(suffix_tokens.size(), (n_batch)/4);
                    const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);

+                   // fill the rest of the context with extra chunks
+                   const int n_extra_take = std::min<int>(std::max<int>(0, slot.n_ctx - (n_batch) - 2*slot.n_predict), slot.extra_tokens.size());
+
                    prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
                    suffix_tokens.resize(n_suffix_take);
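To make the clipping above concrete, here is a standalone sketch of the same budget arithmetic with made-up sizes. Only the formulas mirror the patch; the sizes are illustrative, and the interpretation of the `- 3` term (room for the FIM special tokens) is an assumption.

```cpp
#include <algorithm>
#include <cstdio>

// Standalone sketch of the budget computed above, with made-up sizes.
// Mirrors: n_suffix_take = min(|suffix|, n_batch/4)
//          n_prefix_take = min(|prefix|, (n_batch - 3) - n_suffix_take)   // - 3 presumably reserves the FIM specials
//          n_extra_take  = min(max(0, n_ctx - n_batch - 2*n_predict), |extra|)
int main() {
    const int n_batch   = 2048;   // batch size (illustrative)
    const int n_ctx     = 8192;   // slot context size (illustrative)
    const int n_predict = 512;    // tokens reserved for generation (illustrative)

    const int n_suffix  = 1000;   // tokenized suffix length (illustrative)
    const int n_prefix  = 3000;   // tokenized prefix length (illustrative)
    const int n_extra   = 20000;  // tokenized extra chunks length (illustrative)

    const int n_suffix_take = std::min(n_suffix, n_batch/4);                                   //  512
    const int n_prefix_take = std::min(n_prefix, (n_batch - 3) - n_suffix_take);               // 1533
    const int n_extra_take  = std::min(std::max(0, n_ctx - n_batch - 2*n_predict), n_extra);   // 5120

    printf("take: prefix %d, suffix %d, extra %d\n", n_prefix_take, n_suffix_take, n_extra_take);
    return 0;
}
```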
@@ -1954,6 +2027,11 @@ struct server_context {
                        embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
                    }

+                   SLT_DBG(slot, "extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", slot.n_ctx, n_extra_take, (int) slot.extra_tokens.size());
+
+                   // put the extra context before the FIM prefix
+                   embd_inp.insert(embd_inp.begin(), slot.extra_tokens.end() - n_extra_take, slot.extra_tokens.end());
+
                    embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
                    embd_inp.push_back(llama_token_fim_mid(model));
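Illustration only: with the extra chunks inserted at the front of `embd_inp`, the final infill prompt for a model with FIM special tokens ends up ordered roughly as below. The bracketed names stand in for the model's actual special tokens and BOS handling is omitted.

```cpp
#include <cstdio>
#include <string>
#include <vector>

// Bracketed names stand in for the model's real FIM special tokens.
static std::vector<std::string> fim_layout(const std::vector<std::string> & extra,
                                           const std::string & prefix,
                                           const std::string & suffix) {
    std::vector<std::string> out;
    out.insert(out.end(), extra.begin(), extra.end()); // extra chunks go first
    out.push_back("[FIM_PRE]");
    out.push_back(prefix);
    out.push_back("[FIM_SUF]");
    out.push_back(suffix);
    out.push_back("[FIM_MID]");                        // generation continues from here
    return out;
}

int main() {
    const std::vector<std::string> extra = {"[FIM_REP]myproject\n", "[FIM_SEP]utils.h\nint add(int a, int b);\n"};
    for (const auto & piece : fim_layout(extra, "int main() {\n    int x = ", ";\n}\n")) {
        printf("%s", piece.c_str());
    }
    return 0;
}
```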
@@ -2012,7 +2090,7 @@ struct server_context {
                }
                slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);

-               // if input prompt is too big, truncate it (if group attention self-extend is disabled)
+               // if input prompt is too big, truncate it
                if (slot.n_prompt_tokens >= slot.n_ctx) {
                    const int n_left = slot.n_ctx - slot.params.n_keep;
@@ -2042,12 +2120,82 @@ struct server_context {
                if (slot.params.cache_prompt) {
                    // reuse any previously computed tokens that are common with the new prompt
-                   slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+                   slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);

                    // push the prompt into the sampling context (do not apply grammar)
                    for (int i = 0; i < slot.n_past; ++i) {
                        common_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
                    }
+
+                   // reuse chunks from the cached prompt by shifting their KV cache in the new position
+                   if (params.n_cache_reuse > 0) {
+                       size_t head_c = slot.n_past; // cache
+                       size_t head_p = slot.n_past; // current prompt
+
+                       SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
+
+                       while (head_c < slot.cache_tokens.size() &&
+                              head_p < prompt_tokens.size()) {
+                           if (llama_token_is_control(model, slot.cache_tokens[head_c]) &&
+                               slot.cache_tokens[head_c] != llama_token_fim_rep(model) &&
+                               slot.cache_tokens[head_c] != llama_token_fim_sep(model)) {
+                               break;
+                           }
+
+                           if (llama_token_is_control(model, prompt_tokens[head_p]) &&
+                               prompt_tokens[head_p] != llama_token_fim_rep(model) &&
+                               prompt_tokens[head_p] != llama_token_fim_sep(model)) {
+                               break;
+                           }
+
+                           size_t n_match = 0;
+
+                           while (head_c + n_match < slot.cache_tokens.size() &&
+                                  head_p + n_match < prompt_tokens.size() &&
+                                  slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
+                               if (llama_token_is_control(model, slot.cache_tokens[head_c + n_match]) &&
+                                   slot.cache_tokens[head_c + n_match] != llama_token_fim_rep(model) &&
+                                   slot.cache_tokens[head_c + n_match] != llama_token_fim_sep(model)) {
+                                   break;
+                               }
+
+                               if (llama_token_is_control(model, prompt_tokens[head_p + n_match]) &&
+                                   prompt_tokens[head_p + n_match] != llama_token_fim_rep(model) &&
+                                   prompt_tokens[head_p + n_match] != llama_token_fim_sep(model)) {
+                                   break;
+                               }
+
+                               n_match++;
+                           }
+
+                           if (n_match >= (size_t) params.n_cache_reuse) {
+                               SLT_DBG(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
+                               //for (size_t i = head_p; i < head_p + n_match; i++) {
+                               //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+                               //}
+
+                               const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
+
+                               llama_kv_cache_seq_rm (ctx, slot.id + 1, head_p, head_c);
+                               llama_kv_cache_seq_add(ctx, slot.id + 1, head_c, -1, kv_shift);
+
+                               for (size_t i = 0; i < n_match; i++) {
+                                   slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
+
+                                   common_sampler_accept(slot.smpl, slot.cache_tokens[head_p + i], false);
+
+                                   slot.n_past++;
+                               }
+
+                               head_c += n_match;
+                               head_p += n_match;
+                           } else {
+                               head_c += 1;
+                           }
+                       }
+
+                       SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
+                   }
                }
            }
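The two `llama_kv_cache_seq_*` calls above are what make the reuse cheap: the stale cells in `[head_p, head_c)` are dropped and the matched chunk starting at `head_c` is slid left by `kv_shift` so that it lines up with its new position in the prompt. The toy model below (a plain vector of cell positions, not the llama.cpp API) mimics that effect under those assumptions.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Toy stand-in for one sequence's KV cache: just the position of each cached cell.
static void reuse_shift(std::vector<int64_t> & pos, int64_t head_p, int64_t head_c) {
    const int64_t kv_shift = head_p - head_c; // negative: slide the kept chunk to the left

    // analogous to llama_kv_cache_seq_rm(ctx, seq, head_p, head_c):
    // drop the stale cells whose position falls in [head_p, head_c)
    pos.erase(std::remove_if(pos.begin(), pos.end(),
                             [&](int64_t p) { return p >= head_p && p < head_c; }),
              pos.end());

    // analogous to llama_kv_cache_seq_add(ctx, seq, head_c, -1, kv_shift):
    // shift every remaining cell at position >= head_c by kv_shift
    for (auto & p : pos) {
        if (p >= head_c) {
            p += kv_shift;
        }
    }
}

int main() {
    std::vector<int64_t> pos = {0, 1, 2, 3, 4, 5, 6, 7}; // 8 cached cells
    reuse_shift(pos, /*head_p =*/ 2, /*head_c =*/ 5);    // reuse the chunk that started at 5
    for (auto p : pos) {
        printf("%lld ", (long long) p);                  // prints: 0 1 2 3 4
    }
    printf("\n");
    return 0;
}
```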
@@ -3257,6 +3405,7 @@ int main(int argc, char ** argv) {
    ctx_server.queue_tasks.on_new_task(std::bind(
        &server_context::process_single_task, &ctx_server, std::placeholders::_1));
+
    ctx_server.queue_tasks.on_update_slots(std::bind(
        &server_context::update_slots, &ctx_server));