Skip to content

Commit e2435d6

Browse files
committed
support gemini embeddings
1 parent d47e574 commit e2435d6

File tree

7 files changed

+40
-17
lines changed

7 files changed

+40
-17
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,7 @@ agent = AgentS2(
307307
action_space="pyautogui",
308308
observation_type="screenshot",
309309
search_engine="Perplexica" # Assuming you have set up Perplexica.
310+
embedding_engine_type="openai" # Supports "gemini", "openai"
310311
)
311312
```
312313

gui_agents/s2/agents/agent_s.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from gui_agents.s2.agents.manager import Manager
1010
from gui_agents.s2.utils.common_utils import Node
1111
from gui_agents.utils import download_kb_data
12+
from gui_agents.s2.core.engine import OpenAIEmbeddingEngine, GeminiEmbeddingEngine
1213

1314
logger = logging.getLogger("desktopenv.agent")
1415

@@ -93,6 +94,7 @@ def __init__(
9394
memory_root_path: str = os.getcwd(),
9495
memory_folder_name: str = "kb_s2",
9596
kb_release_tag: str = "v0.2.2",
97+
embedding_engine_type: str = "openai",
9698
):
9799
"""Initialize AgentS2
98100
@@ -106,6 +108,7 @@ def __init__(
106108
memory_root_path: Path to memory directory. Defaults to current working directory.
107109
memory_folder_name: Name of memory folder. Defaults to "kb_s2".
108110
kb_release_tag: Release tag for knowledge base. Defaults to "v0.2.2".
111+
embedding_engine_type: Embedding engine to use for knowledge base. Defaults to "openai". Supports "openai" and "gemini".
109112
"""
110113
super().__init__(
111114
engine_params,
@@ -147,23 +150,30 @@ def __init__(
147150
"Note, the knowledge is continually updated during inference. Deleting the knowledge base will wipe out all experience gained since the last knowledge base download."
148151
)
149152

153+
if embedding_engine_type == "openai":
154+
self.embedding_engine = OpenAIEmbeddingEngine()
155+
elif embedding_engine_type == "gemini":
156+
self.embedding_engine = GeminiEmbeddingEngine()
157+
150158
self.reset()
151159

152160
def reset(self) -> None:
153161
"""Reset agent state and initialize components"""
154162
# Initialize core components
155163
self.planner = Manager(
156-
self.engine_params,
157-
self.grounding_agent,
158-
platform=self.platform,
159-
search_engine=self.engine,
164+
engine_params=self.engine_params,
165+
grounding_agent=self.grounding_agent,
160166
local_kb_path=self.local_kb_path,
167+
embedding_engine=self.embedding_engine,
168+
search_engine=self.engine,
169+
platform=self.platform,
161170
)
162171
self.executor = Worker(
163-
self.engine_params,
164-
self.grounding_agent,
165-
platform=self.platform,
172+
engine_params=self.engine_params,
173+
grounding_agent=self.grounding_agent,
166174
local_kb_path=self.local_kb_path,
175+
embedding_engine=self.embedding_engine,
176+
platform=self.platform,
167177
)
168178

169179
# Reset state variables

gui_agents/s2/agents/manager.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from gui_agents.s2.core.module import BaseModule
99
from gui_agents.s2.core.knowledge import KnowledgeBase
1010
from gui_agents.s2.memory.procedural_memory import PROCEDURAL_MEMORY
11+
from gui_agents.s2.core.engine import OpenAIEmbeddingEngine
1112
from gui_agents.s2.utils.common_utils import (
1213
Dag,
1314
Node,
@@ -27,6 +28,7 @@ def __init__(
2728
engine_params: Dict,
2829
grounding_agent: ACI,
2930
local_kb_path: str,
31+
embedding_engine=OpenAIEmbeddingEngine(),
3032
search_engine: Optional[str] = None,
3133
multi_round: bool = False,
3234
platform: str = platform.system().lower(),
@@ -55,10 +57,12 @@ def __init__(
5557

5658
self.local_kb_path = local_kb_path
5759

60+
self.embedding_engine = embedding_engine
5861
self.knowledge_base = KnowledgeBase(
59-
self.local_kb_path,
60-
platform,
61-
engine_params,
62+
embedding_engine=self.embedding_engine,
63+
local_kb_path=self.local_kb_path,
64+
platform=platform,
65+
engine_params=engine_params,
6266
)
6367

6468
self.planner_history = []

gui_agents/s2/agents/worker.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from gui_agents.s2.core.module import BaseModule
99
from gui_agents.s2.core.knowledge import KnowledgeBase
1010
from gui_agents.s2.memory.procedural_memory import PROCEDURAL_MEMORY
11+
from gui_agents.s2.core.engine import OpenAIEmbeddingEngine
1112
from gui_agents.s2.utils.common_utils import (
1213
Node,
1314
calculate_tokens,
@@ -26,8 +27,8 @@ def __init__(
2627
engine_params: Dict,
2728
grounding_agent: ACI,
2829
local_kb_path: str,
30+
embedding_engine=OpenAIEmbeddingEngine(),
2931
platform: str = platform.system().lower(),
30-
search_engine: str = "perplexica",
3132
enable_reflection: bool = True,
3233
use_subtask_experience: bool = True,
3334
):
@@ -42,8 +43,6 @@ def __init__(
4243
Path to knowledge base
4344
platform: str
4445
OS platform the agent runs on (darwin, linux, windows)
45-
search_engine: str
46-
The search engine to use
4746
enable_reflection: bool
4847
Whether to enable reflection
4948
use_subtask_experience: bool
@@ -53,7 +52,7 @@ def __init__(
5352

5453
self.grounding_agent = grounding_agent
5554
self.local_kb_path = local_kb_path
56-
self.search_engine = search_engine
55+
self.embedding_engine = embedding_engine
5756
self.enable_reflection = enable_reflection
5857
self.use_subtask_experience = use_subtask_experience
5958
self.reset()
@@ -74,6 +73,7 @@ def reset(self):
7473
)
7574

7675
self.knowledge_base = KnowledgeBase(
76+
embedding_engine=self.embedding_engine,
7777
local_kb_path=self.local_kb_path,
7878
platform=self.platform,
7979
engine_params=self.engine_params,

gui_agents/s2/cli_app.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,13 @@ def main():
206206
help="The API key of the grounding model.",
207207
)
208208

209+
parser.add_argument(
210+
"--embedding_engine_type",
211+
type=str,
212+
default="openai",
213+
help="Specify the embedding engine type (supports openai, gemini)",
214+
)
215+
209216
args = parser.parse_args()
210217
assert (
211218
args.grounding_model_provider and args.grounding_model
@@ -257,6 +264,7 @@ def main():
257264
action_space="pyautogui",
258265
observation_type="mixed",
259266
search_engine=None,
267+
embedding_engine_type=args.embedding_engine_type,
260268
)
261269

262270
while True:

gui_agents/s2/core/knowledge.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from gui_agents.s2.core.module import BaseModule
99
from gui_agents.s2.memory.procedural_memory import PROCEDURAL_MEMORY
10-
from gui_agents.s2.core.engine import OpenAIEmbeddingEngine
1110
from gui_agents.s2.utils.common_utils import (
1211
call_llm_safe,
1312
load_embeddings,
@@ -20,6 +19,7 @@
2019
class KnowledgeBase(BaseModule):
2120
def __init__(
2221
self,
22+
embedding_engine,
2323
local_kb_path: str,
2424
platform: str,
2525
engine_params: Dict,
@@ -30,8 +30,7 @@ def __init__(
3030
self.local_kb_path = local_kb_path
3131

3232
# initialize embedding engine
33-
# TODO: Support other embedding engines
34-
self.embedding_engine = OpenAIEmbeddingEngine()
33+
self.embedding_engine = embedding_engine
3534

3635
# Initialize paths for different memory types
3736
self.episodic_memory_path = os.path.join(

osworld_setup/s2/run.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ def test(args: argparse.Namespace, test_all_meta: dict) -> None:
219219
memory_root_path=os.getcwd(),
220220
memory_folder_name=args.kb_name,
221221
kb_release_tag="v0.2.2",
222+
embedding_engine_type="openai",
222223
)
223224

224225
env = DesktopEnv(

0 commit comments

Comments
 (0)