
Commit 7cb9f3d

Add Llama CPP
* added a Llama CPP invocation layer.
* Readme section.
* Tutorial notebook
1 parent 89a2522 commit 7cb9f3d

File tree

9 files changed: +942 -1 lines changed


README.md

Lines changed: 4 additions & 0 deletions
@@ -53,6 +53,10 @@ For a brief overview of the various unique components in fastRAG refer to the [C
 <td><a href="components.md#fastrag-running-llms-with-onnx-runtime">ONNX Runtime</a></td>
 <td><em>Running LLMs with optimized ONNX-runtime</td>
 </tr>
+<tr>
+<td><a href="components.md#fastrag-running-rag-pipelines-with-llms-on-a-llama-cpp-backend">Llama-CPP</a></td>
+<td><em>Running RAG Pipelines with LLMs on a Llama CPP backend</td>
+</tr>
 <tr>
 <td colspan="2"><strong><em>Optimized Components</em></td>
 </tr>

components.md

Lines changed: 31 additions & 0 deletions
@@ -146,6 +146,37 @@ PrompterModel = PromptModel(
 )
 ```
 
+## fastRAG Running RAG Pipelines with LLMs on a Llama CPP backend
+
+To run LLMs effectively on CPUs, especially on client-side machines, we offer a method for running LLMs using the [llama-cpp](https://github.com/ggerganov/llama.cpp) library.
+We recommend checking out our [tutorial notebook](examples/client_inference_with_Llama_cpp.ipynb) with all the details, including steps such as downloading GGUF models.
+
+### Installation
+
+Run the following command to install the required dependencies:
+
+```
+pip install -e .[llama_cpp]
+```
+
+For more information regarding the installation process, we recommend checking out the [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) repository.
+
+
+### Loading the Model
+
+Now that our model is downloaded, we can load it in our framework by specifying the ```LlamaCPPInvocationLayer``` invocation layer.
+
+```python
+PrompterModel = PromptModel(
+    model_name_or_path="models/marcoroni-7b-v3.Q4_K_M.gguf",
+    invocation_layer_class=LlamaCPPInvocationLayer,
+    model_kwargs=dict(
+        max_new_tokens=100
+    )
+)
+```
+
 ## Optimized Embedding Models
 
 Bi-encoder Embedders are key components of Retrieval Augmented Generation pipelines. Mainly used for indexing documents and for online re-ranking. We provide support for quantized `int8` models that have low latency and high throughput, using [`optimum-intel`](https://github.com/huggingface/optimum-intel) framework.
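Editor's note: the "Loading the Model" snippet added above omits the surrounding imports. A minimal end-to-end sketch, assuming Haystack v1-style `PromptModel`/`PromptNode` APIs and the `fastrag.prompters` re-export introduced later in this commit (the model path is just the example GGUF file used in the docs):

```python
# Hedged sketch, not part of the commit: wire the llama.cpp-backed PromptModel
# into a PromptNode and run a single prompt.
from haystack.nodes import PromptModel, PromptNode

from fastrag.prompters import LlamaCPPInvocationLayer

# Load a local GGUF model through the llama.cpp invocation layer.
prompter_model = PromptModel(
    model_name_or_path="models/marcoroni-7b-v3.Q4_K_M.gguf",
    invocation_layer_class=LlamaCPPInvocationLayer,
    model_kwargs=dict(max_new_tokens=100),
)

# Wrap the model in a PromptNode and generate a completion for one prompt.
prompt_node = PromptNode(prompter_model)
print(prompt_node("Explain in one sentence what a GGUF model file is."))
```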

examples.md

Lines changed: 2 additions & 1 deletion
@@ -7,7 +7,8 @@
 | RAG pipeline with FiD generator | [:notebook_with_decorative_cover:](examples/fid_promping.ipynb) |
 | RAG pipeline with REPLUG-based generator | [:notebook_with_decorative_cover:](examples/replug_parallel_reader.ipynb) |
 | RAG pipeline with LLMs running on Gaudi2 |[:notebook_with_decorative_cover:](examples/inference_with_gaudi.ipynb) |
-| RAG pipeline with quantized LLMs running on ONNX-running backend | [:notebook_with_decorative_cover:](examples/inference_with_gaudi.ipynb) |
+| RAG pipeline with quantized LLMs running on ONNX-running backend | [:notebook_with_decorative_cover:](examples/rag_with_quantized_llm.ipynb) |
+| RAG pipeline with LLMs running on Llama-CPP backend | [:notebook_with_decorative_cover:](examples/client_inference_with_Llama_cpp.ipynb) |
 | Optimized and quantized Embeddings models for retrieval and ranking | [:notebook_with_decorative_cover:](examples/optimized-embeddings.ipynb) |
 | RAG pipeline with PLAID index and ColBERT Ranker | [:notebook_with_decorative_cover:](examples/plaid_colbert_pipeline.ipynb) |
 | RAG pipeline with Qdrant index | [:notebook_with_decorative_cover:](examples/qdrant_document_store.ipynb) |

examples/client_inference_with_Llama_cpp.ipynb

Lines changed: 747 additions & 0 deletions
Large diffs are not rendered by default.

fastrag/prompters/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,5 +3,6 @@
 from fastrag.prompters.invocation_layers.gaudi_hugging_face_inference import (
     GaudiHFLocalInvocationLayer,
 )
+from fastrag.prompters.invocation_layers.llama_cpp import LlamaCPPInvocationLayer
 from fastrag.prompters.invocation_layers.ort import ORTInvocationLayer
 from fastrag.prompters.invocation_layers.vqa import VQAHFLocalInvocationLayer
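For reference (not part of the diff), the effect of this one-line change is that the new invocation layer becomes importable from the package root:

```python
# After this commit, the class is re-exported by fastrag.prompters, so it can be
# imported without spelling out the full module path.
from fastrag.prompters import LlamaCPPInvocationLayer
```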
fastrag/prompters/invocation_layers/llama_cpp.py

Lines changed: 117 additions & 0 deletions

@@ -0,0 +1,117 @@

```python
import logging
import sys
from typing import Dict, List, Optional, Union

from haystack.lazy_imports import LazyImport
from haystack.nodes.prompt.invocation_layer import PromptModelInvocationLayer
from haystack.nodes.prompt.invocation_layer.hugging_face import HFLocalInvocationLayer

with LazyImport("Install llama_cpp using 'pip install -e .[llama_cpp]'") as llama_cpp_import:
    from llama_cpp import Llama

logger = logging.getLogger(__name__)


class LlamaCPPInvocationLayer(HFLocalInvocationLayer):
    """
    An invocation layer that loads a local GGUF model with llama.cpp
    (via llama-cpp-python) and runs generation on the CPU.
    """

    def __init__(
        self,
        model_name_or_path: str = "llama-model.gguf",
        max_length: int = 100,
        use_auth_token: Optional[Union[str, bool]] = None,
        **kwargs,
    ):
        PromptModelInvocationLayer.__init__(self, model_name_or_path)

        self.llm = Llama(model_path=model_name_or_path)
        self.max_length = max_length
        self.max_new_tokens = kwargs.get("max_new_tokens", 100)

        # Additional properties required by the invocation layer interface
        self.model_max_length = kwargs.get("model_max_length", sys.maxsize)
        self.generation_kwargs = kwargs

    def _ensure_token_limit(
        self, prompt: Union[str, List[Dict[str, str]]]
    ) -> Union[str, List[Dict[str, str]]]:
        """Ensure that the length of the prompt and answer is within the max tokens limit of the model.
        If needed, truncate the prompt text so that it fits within the limit.

        :param prompt: Prompt text to be sent to the generative model.
        """
        model_max_length = self.model_max_length
        tokenized_prompt = self.llm.tokenize(bytes(prompt, "utf-8"))
        n_prompt_tokens = len(tokenized_prompt)
        n_answer_tokens = self.max_length
        if (n_prompt_tokens + n_answer_tokens) <= model_max_length:
            return prompt

        logger.warning(
            "The prompt has been truncated from %s tokens to %s tokens so that the prompt length and "
            "answer length (%s tokens) fit within the max token limit (%s tokens). "
            "Shorten the prompt to prevent it from being cut off.",
            n_prompt_tokens,
            max(0, model_max_length - n_answer_tokens),
            n_answer_tokens,
            model_max_length,
        )

        decoded_string = self.llm.detokenize(
            tokenized_prompt[: model_max_length - n_answer_tokens]
        ).decode("utf-8")
        return decoded_string

    def invoke(self, *args, **kwargs):
        """
        Take a prompt and return a list of generated texts using the local llama.cpp model.
        :return: A list of generated texts.

        Note: only a subset of the generation kwargs (e.g. max_new_tokens, return_full_text)
        is honoured by the llama.cpp call; other kwargs are ignored.
        """
        output: List[Dict[str, str]] = []
        stop_words = kwargs.pop("stop_words", [])

        generated_texts = []
        if kwargs and "prompt" in kwargs:
            prompt = kwargs.pop("prompt")

            generation_kwargs = self.generation_kwargs
            model_input_kwargs = {
                key: kwargs[key]
                for key in [
                    "return_tensors",
                    "return_text",
                    "return_full_text",
                    "clean_up_tokenization_spaces",
                    "truncation",
                    "generation_kwargs",
                    "max_new_tokens",
                    "num_beams",
                    "do_sample",
                    "num_return_sequences",
                    "max_length",
                ]
                if key in kwargs
            }

            generation_kwargs.update(model_input_kwargs)
            model_input_kwargs = generation_kwargs

            echo = model_input_kwargs.get("return_full_text", False)
            max_tokens = model_input_kwargs.get("max_new_tokens", self.max_new_tokens)

            output = self.llm(
                prompt,  # Prompt text
                max_tokens=max_tokens,  # Generate up to max_tokens new tokens
                stop=stop_words,  # Stop generation at any of the stop words
                echo=echo,  # Echo the prompt back in the output when return_full_text is set
            )  # Generate a completion; create_completion can also be called directly

            generated_texts = [output["choices"][0]["text"]]

        return generated_texts
```
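As an illustration of the interface above (not part of the commit), a minimal sketch of driving the layer directly, assuming the `llama_cpp` extra is installed and a GGUF model file exists at the given path:

```python
# Hedged sketch: exercise LlamaCPPInvocationLayer as defined above.
from fastrag.prompters.invocation_layers.llama_cpp import LlamaCPPInvocationLayer

layer = LlamaCPPInvocationLayer(
    model_name_or_path="models/marcoroni-7b-v3.Q4_K_M.gguf",
    max_new_tokens=64,  # forwarded via **kwargs and used as the default max_tokens
)

# invoke() expects the prompt as a keyword argument and returns a list of strings.
texts = layer.invoke(prompt="List two advantages of running LLMs with llama.cpp on CPUs.")
print(texts[0])
```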

scripts/optimizations/Llama_CPP.md

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@

# Running RAG Pipelines with LLMs on a Llama CPP backend

To run LLMs effectively on CPUs, especially on client-side machines, we offer a method for running LLMs using the [llama-cpp](https://github.com/ggerganov/llama.cpp) library.
We recommend checking out our [tutorial notebook](../../examples/client_inference_with_Llama_cpp.ipynb) with all the details, including steps such as downloading GGUF models.

## Installation

Run the following command to install the required dependencies:

```
pip install -e .[llama_cpp]
```

For more information regarding the installation process, we recommend checking out the [llama-cpp-python](https://github.com/abetlen/llama-cpp-python) repository.

## Downloading GGUF models

In order to use LlamaCPP, download a GGUF model, the format llama.cpp expects for inference:

```
huggingface-cli download TheBloke/Marcoroni-7B-v3-GGUF marcoroni-7b-v3.Q4_K_M.gguf --local-dir ./models --local-dir-use-symlinks False
```

## Loading the Model

Now that our model is downloaded, we can load it in our framework by specifying the ```LlamaCPPInvocationLayer``` invocation layer.

```python
PrompterModel = PromptModel(
    model_name_or_path="models/marcoroni-7b-v3.Q4_K_M.gguf",
    invocation_layer_class=LlamaCPPInvocationLayer,
    model_kwargs=dict(
        max_new_tokens=100
    )
)
```
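Before wiring the model into a pipeline, a quick way to confirm the downloaded GGUF file loads correctly is to call llama-cpp-python directly. A minimal sketch (not part of the commit; assumes the download step above completed):

```python
# Hedged sanity check: load the downloaded GGUF with llama-cpp-python and
# generate a short completion.
from llama_cpp import Llama

llm = Llama(model_path="models/marcoroni-7b-v3.Q4_K_M.gguf")

# Ask for a short completion; the result is a dict with a "choices" list.
result = llm("Q: Name two benefits of quantized GGUF models. A:", max_tokens=32)
print(result["choices"][0]["text"])
```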

scripts/optimizations/README.md

Lines changed: 1 addition & 0 deletions
@@ -18,3 +18,4 @@ Reduction in bit count leads to a model that requires less memory storage, poten
 | [LLM Quantization](LLM-quantization.md) | `optimum-intel` | CPU |
 | [Bi-encoder Quantization](embedders/README.md) | `optimum-intel` | CPU |
 | [Cross-encoder Quantization](reranker_quantization/quantization.md) | `neural-compressor`, `ipex` | CPU |
+| [LlamaCPP LLMs](Llama_CPP.md) | `llama_cpp` | CPU |

setup.cfg

Lines changed: 3 additions & 0 deletions
@@ -73,6 +73,9 @@ intel =
     intel-extension-for-transformers
     optimum[neural-compressor]
 
+llama_cpp =
+    llama-cpp-python
+
 [flake8]
 ignore = E501
 max-line-length = 100
