Skip to content

Commit 7d1d0ca

Browse files
fixes memgraph async attirbute error (#3209)
1 parent 89b67e0 commit 7d1d0ca

File tree

2 files changed

+32
-12
lines changed

2 files changed

+32
-12
lines changed

mem0/memory/main.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,12 @@
3131
process_telemetry_filters,
3232
remove_code_blocks,
3333
)
34-
from mem0.utils.factory import EmbedderFactory, LlmFactory, VectorStoreFactory
34+
from mem0.utils.factory import (
35+
EmbedderFactory,
36+
GraphStoreFactory,
37+
LlmFactory,
38+
VectorStoreFactory,
39+
)
3540

3641

3742
def _build_filters_and_metadata(
@@ -136,14 +141,8 @@ def __init__(self, config: MemoryConfig = MemoryConfig()):
136141
self.enable_graph = False
137142

138143
if self.config.graph_store.config:
139-
if self.config.graph_store.provider == "memgraph":
140-
from mem0.memory.memgraph_memory import MemoryGraph
141-
elif self.config.graph_store.provider == "neptune":
142-
from mem0.graphs.neptune.main import MemoryGraph
143-
else:
144-
from mem0.memory.graph_memory import MemoryGraph
145-
146-
self.graph = MemoryGraph(self.config)
144+
provider = self.config.graph_store.provider
145+
self.graph = GraphStoreFactory.create(provider, self.config)
147146
self.enable_graph = True
148147
else:
149148
self.graph = None
@@ -989,9 +988,8 @@ def __init__(self, config: MemoryConfig = MemoryConfig()):
989988
self.enable_graph = False
990989

991990
if self.config.graph_store.config:
992-
from mem0.memory.graph_memory import MemoryGraph
993-
994-
self.graph = MemoryGraph(self.config)
991+
provider = self.config.graph_store.provider
992+
self.graph = GraphStoreFactory.create(provider, self.config)
995993
self.enable_graph = True
996994
else:
997995
self.graph = None

mem0/utils/factory.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,25 @@ def create(cls, provider_name, config):
106106
def reset(cls, instance):
107107
instance.reset()
108108
return instance
109+
110+
111+
class GraphStoreFactory:
112+
"""
113+
Factory for creating MemoryGraph instances for different graph store providers.
114+
Usage: GraphStoreFactory.create(provider_name, config)
115+
"""
116+
117+
provider_to_class = {
118+
"memgraph": "mem0.memory.memgraph_memory.MemoryGraph",
119+
"neptune": "mem0.graphs.neptune.main.MemoryGraph",
120+
"default": "mem0.memory.graph_memory.MemoryGraph",
121+
}
122+
123+
@classmethod
124+
def create(cls, provider_name, config):
125+
class_type = cls.provider_to_class.get(provider_name, cls.provider_to_class["default"])
126+
try:
127+
GraphClass = load_class(class_type)
128+
except (ImportError, AttributeError) as e:
129+
raise ImportError(f"Could not import MemoryGraph for provider '{provider_name}': {e}")
130+
return GraphClass(config)

0 commit comments

Comments
 (0)