|
15 | 15 | from git import Head, Repo |
16 | 16 | from typing_extensions import Any, Callable, Counter |
17 | 17 |
|
18 | | -from patchwork.common.utils.dependency import chromadb |
19 | 18 | from patchwork.logger import logger |
20 | 19 | from patchwork.managed_files import HOME_FOLDER |
21 | 20 |
|
@@ -129,68 +128,6 @@ def count_openai_tokens(code: str): |
129 | 128 | return len(_ENCODING.encode(code)) |
130 | 129 |
|
131 | 130 |
|
132 | | -def get_vector_db_path() -> str: |
133 | | - CHROMA_DB_PATH = HOME_FOLDER / "chroma.db" |
134 | | - if CHROMA_DB_PATH: |
135 | | - return str(CHROMA_DB_PATH) |
136 | | - else: |
137 | | - return ".chroma.db" |
138 | | - |
139 | | - |
140 | | -def openai_embedding_model( |
141 | | - inputs: dict, |
142 | | -) -> "chromadb.api.types.EmbeddingFunction"["chromadb.api.types.Documents"] | None: |
143 | | - model = inputs.get(openai_embedding_model.__name__) |
144 | | - if model is None: |
145 | | - return None |
146 | | - |
147 | | - api_key = inputs.get("openai_api_key") |
148 | | - if api_key is None: |
149 | | - raise ValueError("Missing required input data: 'openai_api_key'") |
150 | | - |
151 | | - return chromadb().utils.embedding_functions.OpenAIEmbeddingFunction( |
152 | | - api_key=api_key, |
153 | | - model_name=model, |
154 | | - ) |
155 | | - |
156 | | - |
157 | | -def huggingface_embedding_model( |
158 | | - inputs: dict, |
159 | | -) -> "chromadb.api.types.EmbeddingFunction"["chromadb.api.types.Documents"] | None: |
160 | | - model = inputs.get(huggingface_embedding_model.__name__) |
161 | | - if model is None: |
162 | | - return None |
163 | | - |
164 | | - api_key = inputs.get("openai_api_key") or inputs.get("huggingface_api_key") |
165 | | - if api_key is None: |
166 | | - raise ValueError("Missing required input data: 'openai_api_key' or 'huggingface_api_key'") |
167 | | - |
168 | | - return chromadb().utils.embedding_functions.HuggingFaceEmbeddingFunction( |
169 | | - api_key=api_key, |
170 | | - model_name=model, |
171 | | - ) |
172 | | - |
173 | | - |
174 | | -_EMBEDDING_FUNCS = [openai_embedding_model, huggingface_embedding_model] |
175 | | - |
176 | | -_EMBEDDING_TO_API_KEY_NAME: dict[ |
177 | | - str, Callable[[dict], "chromadb.api.type.EmbeddingFunction"["chromadb.api.types.Documents"] | None] |
178 | | -] = {func.__name__: func for func in _EMBEDDING_FUNCS} |
179 | | - |
180 | | - |
181 | | -def get_embedding_function(inputs: dict) -> "chromadb.api.types.EmbeddingFunction"["chromadb.api.types.Documents"]: |
182 | | - embedding_function = next( |
183 | | - (func(inputs) for input_key, func in _EMBEDDING_TO_API_KEY_NAME.items() if input_key in inputs.keys()), |
184 | | - None, |
185 | | - ) |
186 | | - if embedding_function is None: |
187 | | - raise ValueError( |
188 | | - f"Must specify an embedding model. Available options: {list(_EMBEDDING_TO_API_KEY_NAME.keys())}" |
189 | | - ) |
190 | | - |
191 | | - return embedding_function |
192 | | - |
193 | | - |
194 | 131 | def get_current_branch(repo: Repo) -> Head: |
195 | 132 | remote = repo.remote("origin") |
196 | 133 | if repo.head.is_detached: |
|
0 commit comments