Commit 3b76eea

Merge pull request #2359 from danielaskdd/embedding-limit
Refact: Add Embedding Token Limit Configuration and Improve Error Handling
2 parents 9a2ddce + 8722103 commit 3b76eea

16 files changed: +433 -101

env.example

Lines changed: 25 additions & 13 deletions

@@ -255,28 +255,40 @@ OLLAMA_LLM_NUM_CTX=32768
 ### For OpenAI: Set to 'true' to enable dynamic dimension adjustment
 ### For OpenAI: Set to 'false' (default) to disable sending dimension parameter
 ### Note: Automatically ignored for backends that don't support dimension parameter (e.g., Ollama)
-# EMBEDDING_SEND_DIM=false
 
-EMBEDDING_BINDING=ollama
-EMBEDDING_MODEL=bge-m3:latest
-EMBEDDING_DIM=1024
-EMBEDDING_BINDING_API_KEY=your_api_key
-# If LightRAG deployed in Docker uses host.docker.internal instead of localhost
-EMBEDDING_BINDING_HOST=http://localhost:11434
-
-### OpenAI compatible (VoyageAI embedding openai compatible)
-# EMBEDDING_BINDING=openai
-# EMBEDDING_MODEL=text-embedding-3-large
-# EMBEDDING_DIM=3072
-# EMBEDDING_BINDING_HOST=https://api.openai.com/v1
+# Ollama embedding
+# EMBEDDING_BINDING=ollama
+# EMBEDDING_MODEL=bge-m3:latest
+# EMBEDDING_DIM=1024
 # EMBEDDING_BINDING_API_KEY=your_api_key
+### If LightRAG deployed in Docker uses host.docker.internal instead of localhost
+# EMBEDDING_BINDING_HOST=http://localhost:11434
+
+### OpenAI compatible embedding
+EMBEDDING_BINDING=openai
+EMBEDDING_MODEL=text-embedding-3-large
+EMBEDDING_DIM=3072
+EMBEDDING_SEND_DIM=false
+EMBEDDING_TOKEN_LIMIT=8192
+EMBEDDING_BINDING_HOST=https://api.openai.com/v1
+EMBEDDING_BINDING_API_KEY=your_api_key
 
 ### Optional for Azure
 # AZURE_EMBEDDING_DEPLOYMENT=text-embedding-3-large
 # AZURE_EMBEDDING_API_VERSION=2023-05-15
 # AZURE_EMBEDDING_ENDPOINT=your_endpoint
 # AZURE_EMBEDDING_API_KEY=your_api_key
 
+### Gemini embedding
+# EMBEDDING_BINDING=gemini
+# EMBEDDING_MODEL=gemini-embedding-001
+# EMBEDDING_DIM=1536
+# EMBEDDING_TOKEN_LIMIT=2048
+# EMBEDDING_BINDING_HOST=https://generativelanguage.googleapis.com
+# EMBEDDING_BINDING_API_KEY=your_api_key
+### Gemini embedding requires sending dimension to server
+# EMBEDDING_SEND_DIM=true
+
 ### Jina AI Embedding
 # EMBEDDING_BINDING=jina
 # EMBEDDING_BINDING_HOST=https://api.jina.ai/v1/embeddings
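For context, here is a minimal sketch of how a client script might respect the new EMBEDDING_TOKEN_LIMIT value before sending texts for embedding, mirroring the 90% warning threshold mentioned in the server changes below. The tiktoken tokenizer and the warn_if_near_limit helper are assumptions for the example only; they are not part of this PR or of LightRAG's internals.

```python
import os

import tiktoken  # tokenizer choice is an assumption for the example only


def warn_if_near_limit(texts: list[str]) -> None:
    """Warn when a text approaches the configured embedding token limit."""
    raw_limit = os.getenv("EMBEDDING_TOKEN_LIMIT")
    token_limit = int(raw_limit) if raw_limit else None
    if token_limit is None:
        return  # no limit configured: mirrors "token warning disabled" in the server log
    enc = tiktoken.get_encoding("cl100k_base")
    for i, text in enumerate(texts):
        n_tokens = len(enc.encode(text))
        if n_tokens > 0.9 * token_limit:
            print(f"text[{i}] has {n_tokens} tokens, above 90% of limit {token_limit}")


warn_if_near_limit(["a short chunk", "another chunk"])
```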

lightrag/api/config.py

Lines changed: 5 additions & 0 deletions

@@ -445,6 +445,11 @@ def parse_args() -> argparse.Namespace:
         "EMBEDDING_BATCH_NUM", DEFAULT_EMBEDDING_BATCH_NUM, int
     )
 
+    # Embedding token limit configuration
+    args.embedding_token_limit = get_env_value(
+        "EMBEDDING_TOKEN_LIMIT", None, int, special_none=True
+    )
+
     ollama_server_infos.LIGHTRAG_NAME = args.simulated_model_name
     ollama_server_infos.LIGHTRAG_TAG = args.simulated_model_tag
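A hedged sketch of the behaviour this change relies on: with special_none=True, an unset EMBEDDING_TOKEN_LIMIT should come through as None rather than a cast default, so a provider's own default can win later in the server code. get_env_value_sketch below is an illustrative stand-in, not LightRAG's actual get_env_value helper.

```python
import os
from typing import Any, Callable


def get_env_value_sketch(
    name: str, default: Any, cast: Callable[[str], Any], special_none: bool = False
) -> Any:
    """Illustrative stand-in for the config helper: with special_none=True an
    unset or empty variable yields None instead of the default."""
    raw = os.getenv(name)
    if raw is None or raw.strip() == "":
        return None if special_none else default
    return cast(raw)


# EMBEDDING_TOKEN_LIMIT stays None unless explicitly set,
# leaving room for the provider default to apply downstream.
limit = get_env_value_sketch("EMBEDDING_TOKEN_LIMIT", None, int, special_none=True)
print(limit)
```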

lightrag/api/lightrag_server.py

Lines changed: 151 additions & 29 deletions

@@ -618,33 +618,108 @@ def create_llm_model_kwargs(binding: str, args, llm_timeout: int) -> dict:
 
     def create_optimized_embedding_function(
         config_cache: LLMConfigCache, binding, model, host, api_key, args
-    ):
+    ) -> EmbeddingFunc:
         """
-        Create optimized embedding function with pre-processed configuration for applicable bindings.
-        Uses lazy imports for all bindings and avoids repeated configuration parsing.
+        Create optimized embedding function and return an EmbeddingFunc instance
+        with proper max_token_size inheritance from provider defaults.
+
+        This function:
+        1. Imports the provider embedding function
+        2. Extracts max_token_size and embedding_dim from provider if it's an EmbeddingFunc
+        3. Creates an optimized wrapper that calls the underlying function directly (avoiding double-wrapping)
+        4. Returns a properly configured EmbeddingFunc instance
         """
 
+        # Step 1: Import provider function and extract default attributes
+        provider_func = None
+        provider_max_token_size = None
+        provider_embedding_dim = None
+
+        try:
+            if binding == "openai":
+                from lightrag.llm.openai import openai_embed
+
+                provider_func = openai_embed
+            elif binding == "ollama":
+                from lightrag.llm.ollama import ollama_embed
+
+                provider_func = ollama_embed
+            elif binding == "gemini":
+                from lightrag.llm.gemini import gemini_embed
+
+                provider_func = gemini_embed
+            elif binding == "jina":
+                from lightrag.llm.jina import jina_embed
+
+                provider_func = jina_embed
+            elif binding == "azure_openai":
+                from lightrag.llm.azure_openai import azure_openai_embed
+
+                provider_func = azure_openai_embed
+            elif binding == "aws_bedrock":
+                from lightrag.llm.bedrock import bedrock_embed
+
+                provider_func = bedrock_embed
+            elif binding == "lollms":
+                from lightrag.llm.lollms import lollms_embed
+
+                provider_func = lollms_embed
+
+            # Extract attributes if provider is an EmbeddingFunc
+            if provider_func and isinstance(provider_func, EmbeddingFunc):
+                provider_max_token_size = provider_func.max_token_size
+                provider_embedding_dim = provider_func.embedding_dim
+                logger.debug(
+                    f"Extracted from {binding} provider: "
+                    f"max_token_size={provider_max_token_size}, "
+                    f"embedding_dim={provider_embedding_dim}"
+                )
+        except ImportError as e:
+            logger.warning(f"Could not import provider function for {binding}: {e}")
+
+        # Step 2: Apply priority (user config > provider default)
+        # For max_token_size: explicit env var > provider default > None
+        final_max_token_size = args.embedding_token_limit or provider_max_token_size
+        # For embedding_dim: user config (always has value) takes priority
+        # Only use provider default if user config is explicitly None (which shouldn't happen)
+        final_embedding_dim = (
+            args.embedding_dim if args.embedding_dim else provider_embedding_dim
+        )
+
+        # Step 3: Create optimized embedding function (calls underlying function directly)
         async def optimized_embedding_function(texts, embedding_dim=None):
             try:
                 if binding == "lollms":
                     from lightrag.llm.lollms import lollms_embed
 
-                    return await lollms_embed(
+                    # Get real function, skip EmbeddingFunc wrapper if present
+                    actual_func = (
+                        lollms_embed.func
+                        if isinstance(lollms_embed, EmbeddingFunc)
+                        else lollms_embed
+                    )
+                    return await actual_func(
                         texts, embed_model=model, host=host, api_key=api_key
                     )
                 elif binding == "ollama":
                     from lightrag.llm.ollama import ollama_embed
 
-                    # Use pre-processed configuration if available, otherwise fallback to dynamic parsing
+                    # Get real function, skip EmbeddingFunc wrapper if present
+                    actual_func = (
+                        ollama_embed.func
+                        if isinstance(ollama_embed, EmbeddingFunc)
+                        else ollama_embed
+                    )
+
+                    # Use pre-processed configuration if available
                     if config_cache.ollama_embedding_options is not None:
                         ollama_options = config_cache.ollama_embedding_options
                     else:
-                        # Fallback for cases where config cache wasn't initialized properly
                        from lightrag.llm.binding_options import OllamaEmbeddingOptions
 
                        ollama_options = OllamaEmbeddingOptions.options_dict(args)
 
-                    return await ollama_embed(
+                    return await actual_func(
                         texts,
                         embed_model=model,
                         host=host,
@@ -654,15 +729,30 @@ async def optimized_embedding_function(texts, embedding_dim=None):
                 elif binding == "azure_openai":
                     from lightrag.llm.azure_openai import azure_openai_embed
 
-                    return await azure_openai_embed(texts, model=model, api_key=api_key)
+                    actual_func = (
+                        azure_openai_embed.func
+                        if isinstance(azure_openai_embed, EmbeddingFunc)
+                        else azure_openai_embed
+                    )
+                    return await actual_func(texts, model=model, api_key=api_key)
                 elif binding == "aws_bedrock":
                     from lightrag.llm.bedrock import bedrock_embed
 
-                    return await bedrock_embed(texts, model=model)
+                    actual_func = (
+                        bedrock_embed.func
+                        if isinstance(bedrock_embed, EmbeddingFunc)
+                        else bedrock_embed
+                    )
+                    return await actual_func(texts, model=model)
                 elif binding == "jina":
                     from lightrag.llm.jina import jina_embed
 
-                    return await jina_embed(
+                    actual_func = (
+                        jina_embed.func
+                        if isinstance(jina_embed, EmbeddingFunc)
+                        else jina_embed
+                    )
+                    return await actual_func(
                         texts,
                         embedding_dim=embedding_dim,
                         base_url=host,
@@ -671,16 +761,21 @@ async def optimized_embedding_function(texts, embedding_dim=None):
                 elif binding == "gemini":
                     from lightrag.llm.gemini import gemini_embed
 
-                    # Use pre-processed configuration if available, otherwise fallback to dynamic parsing
+                    actual_func = (
+                        gemini_embed.func
+                        if isinstance(gemini_embed, EmbeddingFunc)
+                        else gemini_embed
+                    )
+
+                    # Use pre-processed configuration if available
                     if config_cache.gemini_embedding_options is not None:
                         gemini_options = config_cache.gemini_embedding_options
                     else:
-                        # Fallback for cases where config cache wasn't initialized properly
                         from lightrag.llm.binding_options import GeminiEmbeddingOptions
 
                         gemini_options = GeminiEmbeddingOptions.options_dict(args)
 
-                    return await gemini_embed(
+                    return await actual_func(
                         texts,
                         model=model,
                         base_url=host,
@@ -691,7 +786,12 @@ async def optimized_embedding_function(texts, embedding_dim=None):
                 else:  # openai and compatible
                     from lightrag.llm.openai import openai_embed
 
-                    return await openai_embed(
+                    actual_func = (
+                        openai_embed.func
+                        if isinstance(openai_embed, EmbeddingFunc)
+                        else openai_embed
+                    )
+                    return await actual_func(
                         texts,
                         model=model,
                         base_url=host,
@@ -701,7 +801,21 @@ async def optimized_embedding_function(texts, embedding_dim=None):
             except ImportError as e:
                 raise Exception(f"Failed to import {binding} embedding: {e}")
 
-        return optimized_embedding_function
+        # Step 4: Wrap in EmbeddingFunc and return
+        embedding_func_instance = EmbeddingFunc(
+            embedding_dim=final_embedding_dim,
+            func=optimized_embedding_function,
+            max_token_size=final_max_token_size,
+            send_dimensions=False,  # Will be set later based on binding requirements
+        )
+
+        # Log final embedding configuration
+        logger.info(
+            f"Embedding config: binding={binding} model={model} "
+            f"embedding_dim={final_embedding_dim} max_token_size={final_max_token_size}"
+        )
+
+        return embedding_func_instance
 
     llm_timeout = get_env_value("LLM_TIMEOUT", DEFAULT_LLM_TIMEOUT, int)
     embedding_timeout = get_env_value(
@@ -735,25 +849,24 @@ async def bedrock_model_complete(
             **kwargs,
         )
 
-    # Create embedding function with optimized configuration
+    # Create embedding function with optimized configuration and max_token_size inheritance
     import inspect
 
-    # Create the optimized embedding function
-    optimized_embedding_func = create_optimized_embedding_function(
+    # Create the EmbeddingFunc instance (now returns complete EmbeddingFunc with max_token_size)
+    embedding_func = create_optimized_embedding_function(
         config_cache=config_cache,
         binding=args.embedding_binding,
         model=args.embedding_model,
         host=args.embedding_binding_host,
         api_key=args.embedding_binding_api_key,
-        args=args,  # Pass args object for fallback option generation
+        args=args,
     )
 
     # Get embedding_send_dim from centralized configuration
     embedding_send_dim = args.embedding_send_dim
 
-    # Check if the function signature has embedding_dim parameter
-    # Note: Since optimized_embedding_func is an async function, inspect its signature
-    sig = inspect.signature(optimized_embedding_func)
+    # Check if the underlying function signature has embedding_dim parameter
+    sig = inspect.signature(embedding_func.func)
     has_embedding_dim_param = "embedding_dim" in sig.parameters
 
     # Determine send_dimensions value based on binding type
@@ -771,18 +884,27 @@ async def bedrock_model_complete(
     else:
         dimension_control = "by not hasparam"
 
+    # Set send_dimensions on the EmbeddingFunc instance
+    embedding_func.send_dimensions = send_dimensions
+
     logger.info(
         f"Send embedding dimension: {send_dimensions} {dimension_control} "
-        f"(dimensions={args.embedding_dim}, has_param={has_embedding_dim_param}, "
+        f"(dimensions={embedding_func.embedding_dim}, has_param={has_embedding_dim_param}, "
        f"binding={args.embedding_binding})"
     )
 
-    # Create EmbeddingFunc with send_dimensions attribute
-    embedding_func = EmbeddingFunc(
-        embedding_dim=args.embedding_dim,
-        func=optimized_embedding_func,
-        send_dimensions=send_dimensions,
-    )
+    # Log max_token_size source
+    if embedding_func.max_token_size:
+        source = (
+            "env variable"
+            if args.embedding_token_limit
+            else f"{args.embedding_binding} provider default"
+        )
+        logger.info(
+            f"Embedding max_token_size: {embedding_func.max_token_size} (from {source})"
+        )
+    else:
+        logger.info("Embedding max_token_size: not set (90% token warning disabled)")
 
     # Configure rerank function based on args.rerank_binding parameter
     rerank_model_func = None
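To make the wrapper logic above easier to follow, here is a self-contained sketch of the two ideas the diff introduces: unwrapping a provider function that is already an EmbeddingFunc so it is not wrapped twice, and merging the token limit with the precedence "explicit config > provider default". EmbeddingFuncSketch and fake_provider_embed are toy stand-ins, not LightRAG APIs.

```python
import asyncio
from dataclasses import dataclass
from typing import Awaitable, Callable, Optional


@dataclass
class EmbeddingFuncSketch:  # simplified stand-in for lightrag.utils.EmbeddingFunc
    embedding_dim: int
    func: Callable[..., Awaitable[list]]
    max_token_size: Optional[int] = None
    send_dimensions: bool = False


async def fake_provider_embed(texts, model=None):
    return [[0.0] * 4 for _ in texts]  # placeholder vectors


# A provider module may already export an EmbeddingFunc-like object with defaults.
provider = EmbeddingFuncSketch(embedding_dim=4, func=fake_provider_embed, max_token_size=8192)

# Unwrap to the real callable so it is not wrapped twice.
actual_func = provider.func if isinstance(provider, EmbeddingFuncSketch) else provider

# Precedence used by the new code: explicit config value wins, else provider default.
user_token_limit = None  # e.g. when EMBEDDING_TOKEN_LIMIT is unset
final_max_token_size = user_token_limit or provider.max_token_size  # -> 8192


async def optimized(texts, embedding_dim=None):
    return await actual_func(texts, model="demo-model")


wrapped = EmbeddingFuncSketch(
    embedding_dim=4,
    func=optimized,
    max_token_size=final_max_token_size,
)

print(asyncio.run(wrapped.func(["hello"])), wrapped.max_token_size)
```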

lightrag/lightrag.py

Lines changed: 13 additions & 0 deletions

@@ -276,6 +276,9 @@ class LightRAG:
     embedding_func: EmbeddingFunc | None = field(default=None)
     """Function for computing text embeddings. Must be set before use."""
 
+    embedding_token_limit: int | None = field(default=None, init=False)
+    """Token limit for embedding model. Set automatically from embedding_func.max_token_size in __post_init__."""
+
     embedding_batch_num: int = field(default=int(os.getenv("EMBEDDING_BATCH_NUM", 10)))
     """Batch size for embedding computations."""
 
@@ -519,6 +522,16 @@ def __post_init__(self):
         logger.debug(f"LightRAG init with param:\n {_print_config}\n")
 
         # Init Embedding
+        # Step 1: Capture max_token_size before applying decorator (decorator strips dataclass attributes)
+        embedding_max_token_size = None
+        if self.embedding_func and hasattr(self.embedding_func, "max_token_size"):
+            embedding_max_token_size = self.embedding_func.max_token_size
+            logger.debug(
+                f"Captured embedding max_token_size: {embedding_max_token_size}"
+            )
+        self.embedding_token_limit = embedding_max_token_size
+
+        # Step 2: Apply priority wrapper decorator
         self.embedding_func = priority_limit_async_func_call(
             self.embedding_func_max_async,
             llm_timeout=self.default_embedding_timeout,