diff --git a/LLama/Extensions/DictionaryExtensions.cs b/LLama/Extensions/DictionaryExtensions.cs new file mode 100644 index 000000000..e5a27d6d5 --- /dev/null +++ b/LLama/Extensions/DictionaryExtensions.cs @@ -0,0 +1,14 @@ +using System.Collections.Generic; + +namespace LLama.Extensions +{ + internal static class DictionaryExtensions + { +#if NETSTANDARD2_0 + public static TValue GetValueOrDefault(this IReadOnlyDictionary dictionary, TKey key, TValue defaultValue) + { + return dictionary.TryGetValue(key, out var value) ? value : defaultValue; + } +#endif + } +} diff --git a/LLama/Extensions/IEnumerableExtensions.cs b/LLama/Extensions/IEnumerableExtensions.cs new file mode 100644 index 000000000..ebc234be0 --- /dev/null +++ b/LLama/Extensions/IEnumerableExtensions.cs @@ -0,0 +1,21 @@ +using System.Collections.Generic; +using System.Linq; + +namespace LLama.Extensions +{ + internal static class IEnumerableExtensions + { +#if NETSTANDARD2_0 + public static IEnumerable TakeLast(this IEnumerable source, int count) + { + var list = source.ToList(); + + if (count >= list.Count) + return list; + + list.RemoveRange(0, list.Count - count); + return list; + } +#endif + } +} diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index d83baf3da..fbb2107c8 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -355,36 +355,41 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable lastTokens, Dic int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f, bool penalizeNL = true) { - var n_vocab = _ctx.VocabCount; var logits = _ctx.GetLogits(); // Apply params.logit_bias map - if(logitBias is not null) + if (logitBias is not null) { foreach (var (key, value) in logitBias) - { logits[key] += value; - } } - var candidates = new LLamaTokenData[n_vocab]; - for (llama_token token_id = 0; token_id < n_vocab; token_id++) - candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f); - LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates); - - // Apply penalties - float nl_logit = logits[NativeApi.llama_token_nl()]; - int lastTokensCount = lastTokens.Count(); - var last_n_repeat = Math.Min(Math.Min(lastTokensCount, repeatLastTokensCount), ContextSize); - SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p, - lastTokens.Skip(lastTokensCount - last_n_repeat).ToArray(), - (ulong)last_n_repeat, repeatPenalty); - SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p, - lastTokens.Skip(lastTokensCount - last_n_repeat).ToArray(), - (ulong)last_n_repeat, alphaFrequency, alphaPresence); + // Save the newline logit value + var nl_token = NativeApi.llama_token_nl(); + var nl_logit = logits[nl_token]; + + // Convert logits into token candidates + var candidates_p = LLamaTokenDataArray.Create(logits); + + // Extract most recently returned tokens + var last_n_repeat = Math.Min(ContextSize, repeatLastTokensCount); + var last_n_array = lastTokens.TakeLast(last_n_repeat).ToArray(); + + // Apply penalties to candidates + SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p, last_n_array, repeatPenalty); + SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates_p, last_n_array, alphaFrequency, alphaPresence); + + // Restore newline token logit value if necessary if (!penalizeNL) { - logits[NativeApi.llama_token_nl()] = nl_logit; + var candidatesSpan = candidates_p.data.Span; + for (var i = 0; i < candidates_p.data.Length; i++) + { + ref var item = ref candidatesSpan[i]; + if (item.id == nl_token) + item.logit = nl_logit; + } + candidates_p.sorted = false; } return candidates_p; diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs index add3b7030..7a2965ed8 100644 --- a/LLama/Native/LLamaTokenDataArray.cs +++ b/LLama/Native/LLamaTokenDataArray.cs @@ -2,6 +2,8 @@ using System.Buffers; using System.Runtime.InteropServices; +using llama_token = System.Int32; + namespace LLama.Native { /// @@ -15,9 +17,9 @@ public struct LLamaTokenDataArray public readonly Memory data; /// - /// Indicates if `data` is sorted + /// Indicates if `data` is sorted by logits in descending order. If this is false the token data is in _no particular order_. /// - public readonly bool sorted; + public bool sorted; /// /// Create a new LLamaTokenDataArray @@ -29,6 +31,20 @@ public LLamaTokenDataArray(Memory tokens, bool isSorted = false) data = tokens; sorted = isSorted; } + + /// + /// Create a new LLamaTokenDataArray, copying the data from the given logits + /// + /// + /// + public static LLamaTokenDataArray Create(ReadOnlySpan logits) + { + var candidates = new LLamaTokenData[logits.Length]; + for (var token_id = 0; token_id < logits.Length; token_id++) + candidates[token_id] = new LLamaTokenData(token_id, logits[token_id], 0.0f); + + return new LLamaTokenDataArray(candidates); + } } /// diff --git a/LLama/Native/SamplingApi.cs b/LLama/Native/SamplingApi.cs index d67ac9a8c..56771579b 100644 --- a/LLama/Native/SamplingApi.cs +++ b/LLama/Native/SamplingApi.cs @@ -25,12 +25,25 @@ public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDa /// /// /// + [Obsolete("last_tokens_size parameter is no longer needed")] public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory last_tokens, ulong last_tokens_size, float penalty) + { + llama_sample_repetition_penalty(ctx, candidates, last_tokens, penalty); + } + + /// + /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. + /// + /// + /// Pointer to LLamaTokenDataArray + /// + /// + public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory last_tokens, float penalty) { using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); using var last_tokens_handle = last_tokens.Pin(); - NativeApi.llama_sample_repetition_penalty(ctx, ref st, (int*)last_tokens_handle.Pointer, last_tokens_size, penalty); + NativeApi.llama_sample_repetition_penalty(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty); } /// @@ -42,12 +55,26 @@ public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, L /// /// /// + [Obsolete("last_tokens_size parameter is no longer needed")] public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence) + { + llama_sample_frequency_and_presence_penalties(ctx, candidates, last_tokens, alpha_frequency, alpha_presence); + } + + /// + /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. + /// + /// + /// Pointer to LLamaTokenDataArray + /// + /// + /// + public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory last_tokens, float alpha_frequency, float alpha_presence) { using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); using var last_tokens_handle = last_tokens.Pin(); - NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, (int*)last_tokens_handle.Pointer, last_tokens_size, alpha_frequency, alpha_presence); + NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, alpha_frequency, alpha_presence); } ///