|
31 | 31 | process_telemetry_filters,
|
32 | 32 | remove_code_blocks,
|
33 | 33 | )
|
34 |
| -from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory |
| 34 | +from mem0.utils.factory import ( |
| 35 | + EmbedderFactory, |
| 36 | + GraphStoreFactory, |
| 37 | + LlmFactory, |
| 38 | + VectorStoreFactory, |
| 39 | +) |
35 | 40 |
|
36 | 41 |
|
37 | 42 | def _build_filters_and_metadata(
|
@@ -136,14 +141,8 @@ def __init__(self, config: MemoryConfig = MemoryConfig()):
|
136 | 141 | self.enable_graph = False
|
137 | 142 |
|
138 | 143 | if self.config.graph_store.config:
|
139 |
| - if self.config.graph_store.provider == "memgraph": |
140 |
| - from mem0.memory.memgraph_memory import MemoryGraph |
141 |
| - elif self.config.graph_store.provider == "neptune": |
142 |
| - from mem0.graphs.neptune.main import MemoryGraph |
143 |
| - else: |
144 |
| - from mem0.memory.graph_memory import MemoryGraph |
145 |
| - |
146 |
| - self.graph = MemoryGraph(self.config) |
| 144 | + provider = self.config.graph_store.provider |
| 145 | + self.graph = GraphStoreFactory.create(provider, self.config) |
147 | 146 | self.enable_graph = True
|
148 | 147 | else:
|
149 | 148 | self.graph = None
|
@@ -989,9 +988,8 @@ def __init__(self, config: MemoryConfig = MemoryConfig()):
|
989 | 988 | self.enable_graph = False
|
990 | 989 |
|
991 | 990 | if self.config.graph_store.config:
|
992 |
| - from mem0.memory.graph_memory import MemoryGraph |
993 |
| - |
994 |
| - self.graph = MemoryGraph(self.config) |
| 991 | + provider = self.config.graph_store.provider |
| 992 | + self.graph = GraphStoreFactory.create(provider, self.config) |
995 | 993 | self.enable_graph = True
|
996 | 994 | else:
|
997 | 995 | self.graph = None
|
|
0 commit comments