Skip to content

fix(memory): Replace the outdated manual way of creating the ChatMemoryRepository #3532

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import org.neo4j.driver.Driver;

import org.springframework.ai.chat.memory.repository.neo4j.Neo4jChatMemoryRepository;
import org.springframework.ai.chat.memory.repository.neo4j.Neo4jChatMemoryRepositoryConfig;
import org.springframework.ai.model.chat.memory.autoconfigure.ChatMemoryAutoConfiguration;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
Expand All @@ -43,17 +42,15 @@ public class Neo4jChatMemoryRepositoryAutoConfiguration {
@ConditionalOnMissingBean
public Neo4jChatMemoryRepository neo4jChatMemoryRepository(Neo4jChatMemoryRepositoryProperties properties,
Driver driver) {

var builder = Neo4jChatMemoryRepositoryConfig.builder()
.withMediaLabel(properties.getMediaLabel())
.withMessageLabel(properties.getMessageLabel())
.withMetadataLabel(properties.getMetadataLabel())
.withSessionLabel(properties.getSessionLabel())
.withToolCallLabel(properties.getToolCallLabel())
.withToolResponseLabel(properties.getToolResponseLabel())
.withDriver(driver);

return new Neo4jChatMemoryRepository(builder.build());
return Neo4jChatMemoryRepository.builder()
.driver(driver)
.mediaLabel(properties.getMediaLabel())
.messageLabel(properties.getMessageLabel())
.metadataLabel(properties.getMetadataLabel())
.sessionLabel(properties.getSessionLabel())
.toolCallLabel(properties.getToolCallLabel())
.toolResponseLabel(properties.getToolResponseLabel())
.build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import java.util.UUID;
import java.util.stream.Collectors;

import org.neo4j.driver.Driver;
import org.neo4j.driver.Session;
import org.neo4j.driver.Transaction;
import org.neo4j.driver.TransactionContext;
Expand All @@ -38,6 +39,7 @@
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.content.Media;
import org.springframework.ai.content.MediaContent;
import org.springframework.util.Assert;
import org.springframework.util.MimeType;

/**
Expand All @@ -52,7 +54,7 @@ public final class Neo4jChatMemoryRepository implements ChatMemoryRepository {

private final Neo4jChatMemoryRepositoryConfig config;

public Neo4jChatMemoryRepository(Neo4jChatMemoryRepositoryConfig config) {
private Neo4jChatMemoryRepository(Neo4jChatMemoryRepositoryConfig config) {
this.config = config;
}

Expand Down Expand Up @@ -326,4 +328,56 @@ private List<Map<String, Object>> convertMediaToMap(List<Media> media) {
return mediaMaps;
}

public static Builder builder() {
return new Builder();
}

public static final class Builder {

private final Neo4jChatMemoryRepositoryConfig.Builder builder = Neo4jChatMemoryRepositoryConfig.builder();

private Builder() {
}

public Builder driver(Driver driver) {
this.builder.withDriver(driver);
return this;
}

public Builder sessionLabel(String sessionLabel) {
this.builder.withSessionLabel(sessionLabel);
return this;
}

public Builder toolCallLabel(String toolCallLabel) {
this.builder.withToolCallLabel(toolCallLabel);
return this;
}

public Builder metadataLabel(String metadataLabel) {
this.builder.withMetadataLabel(metadataLabel);
return this;
}

public Builder messageLabel(String messageLabel) {
this.builder.withMessageLabel(messageLabel);
return this;
}

public Builder toolResponseLabel(String toolResponseLabel) {
this.builder.withToolResponseLabel(toolResponseLabel);
return this;
}

public Builder mediaLabel(String mediaLabel) {
this.builder.withMediaLabel(mediaLabel);
return this;
}

public Neo4jChatMemoryRepository build() {
Assert.notNull(this.builder.getDriver(), "Driver cannot be null");
return new Neo4jChatMemoryRepository(this.builder.build());
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class Neo4jChatMemoryRepositoryIT {
.withoutAuthentication()
.withExposedPorts(7474, 7687);

private ChatMemoryRepository chatMemoryRepository;
private Neo4jChatMemoryRepository chatMemoryRepository;

private Driver driver;

Expand All @@ -76,8 +76,8 @@ class Neo4jChatMemoryRepositoryIT {
@BeforeEach
void setUp() {
this.driver = Neo4jDriverFactory.create(neo4jContainer.getBoltUrl());
this.config = Neo4jChatMemoryRepositoryConfig.builder().withDriver(this.driver).build();
this.chatMemoryRepository = new Neo4jChatMemoryRepository(this.config);
this.chatMemoryRepository = Neo4jChatMemoryRepository.builder().driver(driver).build();
this.config = chatMemoryRepository.getConfig();
}

@AfterEach
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ If you'd rather create the `CassandraChatMemoryRepository` manually, you can do
[source,java]
----
ChatMemoryRepository chatMemoryRepository = CassandraChatMemoryRepository
.create(CassandraChatMemoryConfig.builder().withCqlSession(cqlSession));
.create(CassandraChatMemoryRepositoryConfig.builder().withCqlSession(cqlSession).build());

ChatMemory chatMemory = MessageWindowChatMemory.builder()
.chatMemoryRepository(chatMemoryRepository)
Expand Down