diff --git a/README.md b/README.md index 35486da..bc38ff2 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,10 @@ Dies geschieht mithilfe von `docker compose -f elasticsearch/docker-compose.yml ### REST-Requests absetzen In der Datei rest-api.http sind mehrere REST-Requests dokumentiert. Diese können in IntelliJ per Klick aufgerufen werden +### Milvus Vektordatenbank starten +Damit die Anwendung mit der Vektordatenbank Milvus verwendet werden kann, muss per Docker oder minikube/helm eine Milvus instanz gestartet werden. Anschließend können host und port in der application.yml gesetzt werden. +Durch die Konfiguration von milvus host & port werden alle notwendigen Beans von Spring-AI automatisch initialisiert + ## Anfrage an ChatGPT Um eine Anfrage an ChatGPT zu stellen, kann der Endpunkt `GET localhost:8080/chatgpt?message=MeineAnfrage` aufgerufen werden. Die Anfrage wird über Spring AI direkt an ChatGPT weitergeleitet und das Ergebnis zurückgegeben. diff --git a/build.gradle b/build.gradle index 2ba85f1..6ad23c9 100644 --- a/build.gradle +++ b/build.gradle @@ -1,6 +1,6 @@ plugins { id 'java' - id 'org.springframework.boot' version '3.2.1' + id 'org.springframework.boot' version '3.2.2' id 'io.spring.dependency-management' version '1.1.4' } @@ -22,10 +22,12 @@ dependencies { implementation 'org.springframework.boot:spring-boot-starter' implementation 'org.springframework.boot:spring-boot-starter-web' implementation 'org.springframework.boot:spring-boot-starter-data-elasticsearch' + + // Spring AI Dependencies implementation 'org.springframework.ai:spring-ai-openai-spring-boot-starter:0.8.0-SNAPSHOT' - testImplementation 'org.springframework.boot:spring-boot-starter-test' -} + implementation 'org.springframework.ai:spring-ai-milvus-store:0.8.0-SNAPSHOT' -tasks.named('test') { - useJUnitPlatform() + testImplementation 'org.springframework.boot:spring-boot-starter-test' + testImplementation 'org.junit.jupiter:junit-jupiter:5.9.2' + testRuntimeOnly 'org.junit.platform:junit-platform-launcher' } diff --git a/rest-api.http b/rest-api.http index 932b67f..0f7510a 100644 --- a/rest-api.http +++ b/rest-api.http @@ -16,3 +16,11 @@ POST localhost:8080/rag/elasticsearch/ingest GET localhost:8080/rag/elasticsearch?message=ultimate%20mountain%20bike +### Ingest data to milvus vector store +POST localhost:8080/rag/milvus/ingest + +### RAG-Query using milvus vector store +GET localhost:8080/rag/milvus?message=ultimate%20mountain%20bike + + + diff --git a/src/main/java/de/cofinpro/springai/chatgpt/ChatGPTController.java b/src/main/java/de/cofinpro/springai/chatgpt/ChatGPTController.java index 9189ca4..8c4de0d 100644 --- a/src/main/java/de/cofinpro/springai/chatgpt/ChatGPTController.java +++ b/src/main/java/de/cofinpro/springai/chatgpt/ChatGPTController.java @@ -1,6 +1,6 @@ package de.cofinpro.springai.chatgpt; -import org.springframework.ai.openai.client.OpenAiChatClient; +import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RequestMapping; @@ -20,6 +20,6 @@ public ChatGPTController(OpenAiChatClient openAiChatClient) { @GetMapping public String queryGpt(@RequestParam("message") String message) { - return openAiChatClient.generate(message); + return openAiChatClient.call(message); } } diff --git a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/AbstractRetrievalAugmentedGenerationService.java b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/AbstractRetrievalAugmentedGenerationService.java index cb36041..0854afe 100644 --- a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/AbstractRetrievalAugmentedGenerationService.java +++ b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/AbstractRetrievalAugmentedGenerationService.java @@ -1,13 +1,11 @@ package de.cofinpro.springai.retrieval_augmented_generation; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.EmbeddingClient; -import org.springframework.ai.openai.client.OpenAiChatClient; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.openai.OpenAiChatClient; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.reader.JsonReader; -import org.springframework.ai.retriever.VectorStoreRetriever; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.core.io.Resource; @@ -19,8 +17,6 @@ public abstract class AbstractRetrievalAugmentedGenerationService { private final VectorStore vectorStore; - private final VectorStoreRetriever vectorStoreRetriever; - private final OpenAiChatClient openAiChatClient; private final SystemPromptTemplate systemPromptTemplate; @@ -31,7 +27,6 @@ public AbstractRetrievalAugmentedGenerationService(VectorStore vectorStore, Open Resource bikesResource) { this.vectorStore = vectorStore; this.openAiChatClient = openAiChatClient; - vectorStoreRetriever = new VectorStoreRetriever(vectorStore); systemPromptTemplate = new SystemPromptTemplate(systemPromptTemplateResource); this.bikesResource = bikesResource; } @@ -46,10 +41,10 @@ public void ingestDocuments() { } public String retrievalAugmentedGeneration(String message) { - final var similarDocuments = vectorStoreRetriever.retrieve(message); + final var similarDocuments = vectorStore.similaritySearch(message); final var joinedDocuments = similarDocuments.stream().map(Document::getContent).collect(Collectors.joining("\n")); final var systemMessage = systemPromptTemplate.createMessage(Map.of("documents", joinedDocuments)); final var prompt = new Prompt(List.of(systemMessage, new UserMessage(message))); - return openAiChatClient.generate(prompt).getGeneration().getContent(); + return openAiChatClient.call(prompt).getResult().toString(); } } diff --git a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/elasticsearch/ElasticsearchRetrievalAugmentedGenerationService.java b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/elasticsearch/ElasticsearchRetrievalAugmentedGenerationService.java index b2ef7b6..2e89f48 100644 --- a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/elasticsearch/ElasticsearchRetrievalAugmentedGenerationService.java +++ b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/elasticsearch/ElasticsearchRetrievalAugmentedGenerationService.java @@ -1,7 +1,7 @@ package de.cofinpro.springai.retrieval_augmented_generation.elasticsearch; import de.cofinpro.springai.retrieval_augmented_generation.AbstractRetrievalAugmentedGenerationService; -import org.springframework.ai.openai.client.OpenAiChatClient; +import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.beans.factory.annotation.Value; import org.springframework.core.io.Resource; import org.springframework.stereotype.Service; diff --git a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/milvus/MilvusRetrievalAugmentedGenerationController.java b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/milvus/MilvusRetrievalAugmentedGenerationController.java new file mode 100644 index 0000000..aa90bd7 --- /dev/null +++ b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/milvus/MilvusRetrievalAugmentedGenerationController.java @@ -0,0 +1,27 @@ +package de.cofinpro.springai.retrieval_augmented_generation.milvus; + +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.*; + +@RestController +@RequestMapping("/rag/milvus") +public class MilvusRetrievalAugmentedGenerationController { + + private final MilvusRetrievalAugmentedGenerationService milvusRetrievalAugmentedGenerationService; + + public MilvusRetrievalAugmentedGenerationController(MilvusRetrievalAugmentedGenerationService milvusRetrievalAugmentedGenerationService) { + this.milvusRetrievalAugmentedGenerationService = milvusRetrievalAugmentedGenerationService; + } + + @PostMapping("/ingest") + public ResponseEntity ingestDocuments() { + milvusRetrievalAugmentedGenerationService.ingestDocuments(); + return new ResponseEntity<>(HttpStatus.OK); + } + + @GetMapping + public String retrievalAugmentedGeneration(@RequestParam("message") String message) { + return milvusRetrievalAugmentedGenerationService.retrievalAugmentedGeneration(message); + } +} diff --git a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/milvus/MilvusRetrievalAugmentedGenerationService.java b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/milvus/MilvusRetrievalAugmentedGenerationService.java new file mode 100644 index 0000000..5d9bfe1 --- /dev/null +++ b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/milvus/MilvusRetrievalAugmentedGenerationService.java @@ -0,0 +1,49 @@ +package de.cofinpro.springai.retrieval_augmented_generation.milvus; + +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.openai.OpenAiChatClient; +import org.springframework.ai.reader.JsonReader; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.core.io.Resource; +import org.springframework.stereotype.Service; + +import java.util.List; +import java.util.Map; + +@Service +public class MilvusRetrievalAugmentedGenerationService { + + private final OpenAiChatClient openAiChatClient; + private final Resource bikesResource; + private final SystemPromptTemplate systemPromptTemplate; + + private final VectorStore vectorStore; + + @Autowired + public MilvusRetrievalAugmentedGenerationService(OpenAiChatClient openAiChatClient, + VectorStore vectorStore, + @Value("classpath:/bikes.json") Resource bikesResource, + @Value("classpath:/system-prompt-template") Resource systemPromptTemplateResource) { + this.openAiChatClient = openAiChatClient; + this.vectorStore = vectorStore; + this.bikesResource = bikesResource; + this.systemPromptTemplate = new SystemPromptTemplate(systemPromptTemplateResource); + } + + public void ingestDocuments() { + final var jsonReader = new JsonReader(bikesResource, "name", "price", "shortDescription"); + vectorStore.add(jsonReader.get()); + } + + public String retrievalAugmentedGeneration(String message) { + final var test = vectorStore.similaritySearch(SearchRequest.query(message)); + final var systemMessage = systemPromptTemplate.createMessage(Map.of("documents", test)); + final var prompt = new Prompt(List.of(systemMessage, new UserMessage(message))); + return openAiChatClient.call(prompt).getResult().toString(); + } +} diff --git a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/simple/SimpleRetrievalAugmentedGenerationService.java b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/simple/SimpleRetrievalAugmentedGenerationService.java index d418275..a3e11a6 100644 --- a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/simple/SimpleRetrievalAugmentedGenerationService.java +++ b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/simple/SimpleRetrievalAugmentedGenerationService.java @@ -1,23 +1,14 @@ package de.cofinpro.springai.retrieval_augmented_generation.simple; import de.cofinpro.springai.retrieval_augmented_generation.AbstractRetrievalAugmentedGenerationService; -import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingClient; -import org.springframework.ai.openai.client.OpenAiChatClient; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.UserMessage; -import org.springframework.ai.reader.JsonReader; -import org.springframework.ai.retriever.VectorStoreRetriever; +import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.ai.vectorstore.SimpleVectorStore; import org.springframework.beans.factory.annotation.Value; import org.springframework.core.io.Resource; import org.springframework.stereotype.Service; import java.io.File; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; @Service public class SimpleRetrievalAugmentedGenerationService extends AbstractRetrievalAugmentedGenerationService { diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index ea2ddd5..c6e605e 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -1,3 +1,16 @@ spring: elasticsearch: uris: http://localhost:9200 + ai: + openai: + api-key: + chat: + options: + temperature: 0.5 + user: + vectorstore: + milvus: + client: + host: localhost + port: 56875 +