@@ -99,6 +99,7 @@ def __init__(
99
99
memory_folder_name : str = "kb_s2" ,
100
100
kb_release_tag : str = "v0.2.2" ,
101
101
embedding_engine_type : str = "openai" ,
102
+ embedding_engine_params : Dict = {},
102
103
):
103
104
"""Initialize AgentS2
104
105
@@ -113,6 +114,7 @@ def __init__(
113
114
memory_folder_name: Name of memory folder. Defaults to "kb_s2".
114
115
kb_release_tag: Release tag for knowledge base. Defaults to "v0.2.2".
115
116
embedding_engine_type: Embedding engine to use for knowledge base. Defaults to "openai". Supports "openai" and "gemini".
117
+ embedding_engine_params: Parameters for embedding engine. Defaults to {}.
116
118
"""
117
119
super ().__init__ (
118
120
engine_params ,
@@ -155,11 +157,13 @@ def __init__(
155
157
)
156
158
157
159
if embedding_engine_type == "openai" :
158
- self .embedding_engine = OpenAIEmbeddingEngine ()
160
+ self .embedding_engine = OpenAIEmbeddingEngine (** embedding_engine_params )
159
161
elif embedding_engine_type == "gemini" :
160
- self .embedding_engine = GeminiEmbeddingEngine ()
162
+ self .embedding_engine = GeminiEmbeddingEngine (** embedding_engine_params )
161
163
elif embedding_engine_type == "azure" :
162
- self .embedding_engine = AzureOpenAIEmbeddingEngine ()
164
+ self .embedding_engine = AzureOpenAIEmbeddingEngine (
165
+ ** embedding_engine_params
166
+ )
163
167
164
168
self .reset ()
165
169
0 commit comments