From 40ef148f5806cbd64e2a721c9e36835a2ad40bae Mon Sep 17 00:00:00 2001 From: "Leibmann, Ruben EXT" Date: Wed, 24 Jan 2024 16:25:44 +0100 Subject: [PATCH 1/4] fix imports fix testing errors in preparation of gradle 9.0 --- build.gradle | 6 ++---- .../cofinpro/springai/chatgpt/ChatGPTController.java | 2 +- .../AbstractRetrievalAugmentedGenerationService.java | 9 ++------- ...sticsearchRetrievalAugmentedGenerationService.java | 2 +- .../SimpleRetrievalAugmentedGenerationService.java | 11 +---------- src/main/resources/application.yml | 4 ++++ 6 files changed, 11 insertions(+), 23 deletions(-) diff --git a/build.gradle b/build.gradle index 2ba85f1..287edce 100644 --- a/build.gradle +++ b/build.gradle @@ -24,8 +24,6 @@ dependencies { implementation 'org.springframework.boot:spring-boot-starter-data-elasticsearch' implementation 'org.springframework.ai:spring-ai-openai-spring-boot-starter:0.8.0-SNAPSHOT' testImplementation 'org.springframework.boot:spring-boot-starter-test' -} - -tasks.named('test') { - useJUnitPlatform() + testImplementation 'org.junit.jupiter:junit-jupiter:5.9.2' + testRuntimeOnly 'org.junit.platform:junit-platform-launcher' } diff --git a/src/main/java/de/cofinpro/springai/chatgpt/ChatGPTController.java b/src/main/java/de/cofinpro/springai/chatgpt/ChatGPTController.java index 9189ca4..4d63edc 100644 --- a/src/main/java/de/cofinpro/springai/chatgpt/ChatGPTController.java +++ b/src/main/java/de/cofinpro/springai/chatgpt/ChatGPTController.java @@ -1,6 +1,6 @@ package de.cofinpro.springai.chatgpt; -import org.springframework.ai.openai.client.OpenAiChatClient; +import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RequestMapping; diff --git a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/AbstractRetrievalAugmentedGenerationService.java b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/AbstractRetrievalAugmentedGenerationService.java index cb36041..468fa45 100644 --- a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/AbstractRetrievalAugmentedGenerationService.java +++ b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/AbstractRetrievalAugmentedGenerationService.java @@ -1,13 +1,11 @@ package de.cofinpro.springai.retrieval_augmented_generation; import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.EmbeddingClient; -import org.springframework.ai.openai.client.OpenAiChatClient; +import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.ai.prompt.Prompt; import org.springframework.ai.prompt.SystemPromptTemplate; import org.springframework.ai.prompt.messages.UserMessage; import org.springframework.ai.reader.JsonReader; -import org.springframework.ai.retriever.VectorStoreRetriever; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.core.io.Resource; @@ -19,8 +17,6 @@ public abstract class AbstractRetrievalAugmentedGenerationService { private final VectorStore vectorStore; - private final VectorStoreRetriever vectorStoreRetriever; - private final OpenAiChatClient openAiChatClient; private final SystemPromptTemplate systemPromptTemplate; @@ -31,7 +27,6 @@ public AbstractRetrievalAugmentedGenerationService(VectorStore vectorStore, Open Resource bikesResource) { this.vectorStore = vectorStore; this.openAiChatClient = openAiChatClient; - vectorStoreRetriever = new VectorStoreRetriever(vectorStore); systemPromptTemplate = new SystemPromptTemplate(systemPromptTemplateResource); this.bikesResource = bikesResource; } @@ -46,7 +41,7 @@ public void ingestDocuments() { } public String retrievalAugmentedGeneration(String message) { - final var similarDocuments = vectorStoreRetriever.retrieve(message); + final var similarDocuments = vectorStore.similaritySearch(message); final var joinedDocuments = similarDocuments.stream().map(Document::getContent).collect(Collectors.joining("\n")); final var systemMessage = systemPromptTemplate.createMessage(Map.of("documents", joinedDocuments)); final var prompt = new Prompt(List.of(systemMessage, new UserMessage(message))); diff --git a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/elasticsearch/ElasticsearchRetrievalAugmentedGenerationService.java b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/elasticsearch/ElasticsearchRetrievalAugmentedGenerationService.java index b2ef7b6..2e89f48 100644 --- a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/elasticsearch/ElasticsearchRetrievalAugmentedGenerationService.java +++ b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/elasticsearch/ElasticsearchRetrievalAugmentedGenerationService.java @@ -1,7 +1,7 @@ package de.cofinpro.springai.retrieval_augmented_generation.elasticsearch; import de.cofinpro.springai.retrieval_augmented_generation.AbstractRetrievalAugmentedGenerationService; -import org.springframework.ai.openai.client.OpenAiChatClient; +import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.beans.factory.annotation.Value; import org.springframework.core.io.Resource; import org.springframework.stereotype.Service; diff --git a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/simple/SimpleRetrievalAugmentedGenerationService.java b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/simple/SimpleRetrievalAugmentedGenerationService.java index d418275..a3e11a6 100644 --- a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/simple/SimpleRetrievalAugmentedGenerationService.java +++ b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/simple/SimpleRetrievalAugmentedGenerationService.java @@ -1,23 +1,14 @@ package de.cofinpro.springai.retrieval_augmented_generation.simple; import de.cofinpro.springai.retrieval_augmented_generation.AbstractRetrievalAugmentedGenerationService; -import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingClient; -import org.springframework.ai.openai.client.OpenAiChatClient; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.UserMessage; -import org.springframework.ai.reader.JsonReader; -import org.springframework.ai.retriever.VectorStoreRetriever; +import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.ai.vectorstore.SimpleVectorStore; import org.springframework.beans.factory.annotation.Value; import org.springframework.core.io.Resource; import org.springframework.stereotype.Service; import java.io.File; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; @Service public class SimpleRetrievalAugmentedGenerationService extends AbstractRetrievalAugmentedGenerationService { diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index ea2ddd5..a92bc0e 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -1,3 +1,7 @@ spring: elasticsearch: uris: http://localhost:9200 + ai: + azure: + openai: + api-key: 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 From 515c9c1222c69c14ed542f21e5fa32dc29ba35d2 Mon Sep 17 00:00:00 2001 From: Ruben Leibmann Date: Sat, 3 Feb 2024 14:47:01 +0100 Subject: [PATCH 2/4] (non-func): Milvus testing --- build.gradle | 6 +- rest-api.http | 6 + .../springai/chatgpt/ChatGPTController.java | 2 +- ...ctRetrievalAugmentedGenerationService.java | 8 +- ...etrievalAugmentedGenerationController.java | 33 ++++ ...usRetrievalAugmentedGenerationService.java | 166 ++++++++++++++++++ src/main/resources/application.yml | 15 +- 7 files changed, 227 insertions(+), 9 deletions(-) create mode 100644 src/main/java/de/cofinpro/springai/retrieval_augmented_generation/milvus/MilvusRetrievalAugmentedGenerationController.java create mode 100644 src/main/java/de/cofinpro/springai/retrieval_augmented_generation/milvus/MilvusRetrievalAugmentedGenerationService.java diff --git a/build.gradle b/build.gradle index 287edce..6ad23c9 100644 --- a/build.gradle +++ b/build.gradle @@ -1,6 +1,6 @@ plugins { id 'java' - id 'org.springframework.boot' version '3.2.1' + id 'org.springframework.boot' version '3.2.2' id 'io.spring.dependency-management' version '1.1.4' } @@ -22,7 +22,11 @@ dependencies { implementation 'org.springframework.boot:spring-boot-starter' implementation 'org.springframework.boot:spring-boot-starter-web' implementation 'org.springframework.boot:spring-boot-starter-data-elasticsearch' + + // Spring AI Dependencies implementation 'org.springframework.ai:spring-ai-openai-spring-boot-starter:0.8.0-SNAPSHOT' + implementation 'org.springframework.ai:spring-ai-milvus-store:0.8.0-SNAPSHOT' + testImplementation 'org.springframework.boot:spring-boot-starter-test' testImplementation 'org.junit.jupiter:junit-jupiter:5.9.2' testRuntimeOnly 'org.junit.platform:junit-platform-launcher' diff --git a/rest-api.http b/rest-api.http index 932b67f..64c2114 100644 --- a/rest-api.http +++ b/rest-api.http @@ -5,6 +5,12 @@ GET localhost:8080/chatgpt?message=hallo ### Ingest data to simple vector store POST localhost:8080/rag/simple/ingest +### Ingest data to simple vector store +POST localhost:8080/rag/milvus/ingest + +### RAG-Query using milvus vector store +GET localhost:8080/rag/milvus?message=ultimate%20mountain%20bike + ### RAG-Query using simple vector store GET localhost:8080/rag/simple?message=ultimate%20mountain%20bike diff --git a/src/main/java/de/cofinpro/springai/chatgpt/ChatGPTController.java b/src/main/java/de/cofinpro/springai/chatgpt/ChatGPTController.java index 4d63edc..8c4de0d 100644 --- a/src/main/java/de/cofinpro/springai/chatgpt/ChatGPTController.java +++ b/src/main/java/de/cofinpro/springai/chatgpt/ChatGPTController.java @@ -20,6 +20,6 @@ public ChatGPTController(OpenAiChatClient openAiChatClient) { @GetMapping public String queryGpt(@RequestParam("message") String message) { - return openAiChatClient.generate(message); + return openAiChatClient.call(message); } } diff --git a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/AbstractRetrievalAugmentedGenerationService.java b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/AbstractRetrievalAugmentedGenerationService.java index 468fa45..0854afe 100644 --- a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/AbstractRetrievalAugmentedGenerationService.java +++ b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/AbstractRetrievalAugmentedGenerationService.java @@ -1,10 +1,10 @@ package de.cofinpro.springai.retrieval_augmented_generation; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.document.Document; import org.springframework.ai.openai.OpenAiChatClient; -import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.SystemPromptTemplate; -import org.springframework.ai.prompt.messages.UserMessage; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; import org.springframework.ai.reader.JsonReader; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.core.io.Resource; @@ -45,6 +45,6 @@ public String retrievalAugmentedGeneration(String message) { final var joinedDocuments = similarDocuments.stream().map(Document::getContent).collect(Collectors.joining("\n")); final var systemMessage = systemPromptTemplate.createMessage(Map.of("documents", joinedDocuments)); final var prompt = new Prompt(List.of(systemMessage, new UserMessage(message))); - return openAiChatClient.generate(prompt).getGeneration().getContent(); + return openAiChatClient.call(prompt).getResult().toString(); } } diff --git a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/milvus/MilvusRetrievalAugmentedGenerationController.java b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/milvus/MilvusRetrievalAugmentedGenerationController.java new file mode 100644 index 0000000..cb045fd --- /dev/null +++ b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/milvus/MilvusRetrievalAugmentedGenerationController.java @@ -0,0 +1,33 @@ +package de.cofinpro.springai.retrieval_augmented_generation.milvus; + +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; +import org.springframework.web.bind.annotation.*; + +import java.io.IOException; + +@RestController +@RequestMapping("/rag/milvus") +public class MilvusRetrievalAugmentedGenerationController { + + private final MilvusRetrievalAugmentedGenerationService milvusRetrievalAugmentedGenerationService; + + public MilvusRetrievalAugmentedGenerationController(MilvusRetrievalAugmentedGenerationService milvusRetrievalAugmentedGenerationService) { + this.milvusRetrievalAugmentedGenerationService = milvusRetrievalAugmentedGenerationService; + } + + @PostMapping("/ingest") + public ResponseEntity ingestDocuments() { + try { + milvusRetrievalAugmentedGenerationService.ingestDocuments(); + return new ResponseEntity<>(HttpStatus.OK); + } catch (IOException e) { + return new ResponseEntity<>(HttpStatus.BAD_REQUEST); + } + } + + @GetMapping + public String retrievalAugmentedGeneration(@RequestParam("message") String message) { + return milvusRetrievalAugmentedGenerationService.retrievalAugmentedGeneration(message); + } +} diff --git a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/milvus/MilvusRetrievalAugmentedGenerationService.java b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/milvus/MilvusRetrievalAugmentedGenerationService.java new file mode 100644 index 0000000..74d96f6 --- /dev/null +++ b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/milvus/MilvusRetrievalAugmentedGenerationService.java @@ -0,0 +1,166 @@ +package de.cofinpro.springai.retrieval_augmented_generation.milvus; + +import com.alibaba.fastjson.JSON; +import de.cofinpro.springai.retrieval_augmented_generation.AbstractRetrievalAugmentedGenerationService; +import io.milvus.client.MilvusServiceClient; +import io.milvus.common.clientenum.ConsistencyLevelEnum; +import io.milvus.grpc.CreateCollectionRequestOrBuilder; +import io.milvus.grpc.DataType; +import io.milvus.grpc.MutationResult; +import io.milvus.param.IndexType; +import io.milvus.param.MetricType; +import io.milvus.param.R; +import io.milvus.param.RpcStatus; +import io.milvus.param.bulkinsert.BulkInsertParam; +import com.alibaba.fastjson.JSONObject; +import io.milvus.param.collection.*; +import io.milvus.param.dml.InsertParam; +import io.milvus.param.dml.SearchParam; +import io.milvus.param.index.CreateIndexParam; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.chat.prompt.SystemPromptTemplate; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingClient; +import org.springframework.ai.openai.OpenAiChatClient; +import org.springframework.ai.reader.JsonReader; +import org.springframework.ai.vectorstore.SimpleVectorStore; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.core.io.Resource; +import org.springframework.stereotype.Service; +import org.springframework.web.bind.annotation.PostMapping; + +import java.io.File; +import java.io.IOException; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +@Service +public class MilvusRetrievalAugmentedGenerationService { + + private static final File VECTORSTORE_FILE = new File("data/vectorstore.json"); + + private final OpenAiChatClient openAiChatClient; + private final EmbeddingClient embeddingClient; + private final Resource bikesResource; + private final SystemPromptTemplate systemPromptTemplate; + private final MilvusServiceClient milvusServiceClient; + + private static final String COLLECTION_NAME = "TEST_COLLECTION"; + + @Autowired + public MilvusRetrievalAugmentedGenerationService(OpenAiChatClient openAiChatClient, EmbeddingClient embeddingClient, @Value("classpath:/bikes.json") Resource bikesResource, + @Value("classpath:/system-prompt-template") Resource systemPromptTemplateResource, MilvusServiceClient milvusServiceClient) { + this.openAiChatClient = openAiChatClient; + this.embeddingClient = embeddingClient; + this.bikesResource = bikesResource; + this.milvusServiceClient = milvusServiceClient; + this.systemPromptTemplate = new SystemPromptTemplate(systemPromptTemplateResource); + } + + public String retrievalAugmentedGeneration(String message) { + List> searchVectors = List.of(embeddingClient.embed(message).stream().map(Double::floatValue).toList()); + var searchParam = SearchParam.newBuilder() + .withCollectionName(COLLECTION_NAME) + .withConsistencyLevel(ConsistencyLevelEnum.STRONG) + .withOutFields(Arrays.asList("document_content")) + .withTopK(5) + .withVectors(searchVectors) + .withVectorFieldName("document_vectors") + .withParams("{\"nprobe\":10, \"offset\":0}") + .build(); + + + var status = milvusServiceClient.loadCollection(LoadCollectionParam.newBuilder() + .withCollectionName(COLLECTION_NAME) + .build()); + System.out.println(status); + final var similarDocuments = milvusServiceClient.search(searchParam); + milvusServiceClient.releaseCollection(ReleaseCollectionParam.newBuilder() + .withCollectionName(COLLECTION_NAME) + .build()); + final var results = similarDocuments.toString(); + final var systemMessage = systemPromptTemplate.createMessage(Map.of("documents", results)); + final var prompt = new Prompt(List.of(systemMessage, new UserMessage(message))); + return openAiChatClient.call(prompt).getResult().toString(); + } + + public void ingestDocuments() throws IOException { + FieldType fieldType1 = FieldType.newBuilder() + .withName("document_id") + .withDataType(DataType.Int64) + .withPrimaryKey(true) + .withAutoID(true) + .build(); + FieldType fieldType2 = FieldType.newBuilder() + .withName("document_content") + .withDataType(DataType.VarChar) + .withMaxLength(5000) + .build(); + FieldType fieldType3 = FieldType.newBuilder() + .withName("document_vectors") + .withDataType(DataType.FloatVector) + .withDimension(5000) + .build(); + CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder() + .withCollectionName(COLLECTION_NAME) + .withDescription("Test document search") + .withShardsNum(2) + .addFieldType(fieldType1) + .addFieldType(fieldType2) + .addFieldType(fieldType3) + .withEnableDynamicField(true) + .build(); + + var ret = milvusServiceClient.createCollection(createCollectionReq); + if (ret.getStatus() != R.Status.Success.getCode()) { + throw new RuntimeException("Failed to create collection! Error: " + ret.getMessage()); + } + ret = milvusServiceClient.createIndex(CreateIndexParam.newBuilder() + .withCollectionName(COLLECTION_NAME) + .withFieldName("document_vectors") + .withIndexType(IndexType.FLAT) + .withMetricType(MetricType.L2) + .build()); + if (ret.getStatus() != R.Status.Success.getCode()) { + throw new RuntimeException("Failed to create index on vector field! Error: " + ret.getMessage()); + } + ret = milvusServiceClient.createIndex(CreateIndexParam.newBuilder() + .withCollectionName(COLLECTION_NAME) + .withFieldName("document_content") + .withIndexType(IndexType.TRIE) + .build()); + if (ret.getStatus() != R.Status.Success.getCode()) { + throw new RuntimeException("Failed to create index on varchar field! Error: " + ret.getMessage()); + } + + milvusServiceClient.loadCollection(LoadCollectionParam.newBuilder() + .withCollectionName(COLLECTION_NAME) + .build()); + + final var jsonReader = new JsonReader(bikesResource, "name", "price", "shortDescription"); + final List documents = jsonReader.get(); + final var documentContent = documents.stream().map(document -> new JSONObject(Map.of("document_content", document.getContent()))).toList(); + + + var insertRet = milvusServiceClient.bulkInsert(BulkInsertParam.newBuilder().withCollectionName(COLLECTION_NAME).withFiles(List.of("classpath:/bikes.json")).build()); + + /* + R insertRet = milvusServiceClient.insert(InsertParam.newBuilder() + .withCollectionName(COLLECTION_NAME) + .withRows(documentContent) + .build()); */ + if (insertRet.getStatus() != R.Status.Success.getCode()) { + throw new RuntimeException("Failed to insert! Error: " + insertRet.getMessage()); + } + + milvusServiceClient.releaseCollection(ReleaseCollectionParam.newBuilder() + .withCollectionName(COLLECTION_NAME) + .build()); + } +} diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index a92bc0e..66181e8 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -2,6 +2,15 @@ spring: elasticsearch: uris: http://localhost:9200 ai: - azure: - openai: - api-key: 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 + openai: + api-key: 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 + chat: + options: + temperature: 0.5 + user: rubenspringai + vectorstore: + milvus: + client: + host: localhost + port: 56875 + From 92694b8c2b63074ac64e150e2ec9d85e139c366d Mon Sep 17 00:00:00 2001 From: Ruben Leibmann Date: Sat, 3 Feb 2024 16:10:24 +0100 Subject: [PATCH 3/4] RAG with Milvus Vectordatabase --- README.md | 4 + ...etrievalAugmentedGenerationController.java | 10 +- ...usRetrievalAugmentedGenerationService.java | 147 ++---------------- src/main/resources/application.yml | 4 +- 4 files changed, 23 insertions(+), 142 deletions(-) diff --git a/README.md b/README.md index 35486da..bc38ff2 100644 --- a/README.md +++ b/README.md @@ -32,6 +32,10 @@ Dies geschieht mithilfe von `docker compose -f elasticsearch/docker-compose.yml ### REST-Requests absetzen In der Datei rest-api.http sind mehrere REST-Requests dokumentiert. Diese können in IntelliJ per Klick aufgerufen werden +### Milvus Vektordatenbank starten +Damit die Anwendung mit der Vektordatenbank Milvus verwendet werden kann, muss per Docker oder minikube/helm eine Milvus instanz gestartet werden. Anschließend können host und port in der application.yml gesetzt werden. +Durch die Konfiguration von milvus host & port werden alle notwendigen Beans von Spring-AI automatisch initialisiert + ## Anfrage an ChatGPT Um eine Anfrage an ChatGPT zu stellen, kann der Endpunkt `GET localhost:8080/chatgpt?message=MeineAnfrage` aufgerufen werden. Die Anfrage wird über Spring AI direkt an ChatGPT weitergeleitet und das Ergebnis zurückgegeben. diff --git a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/milvus/MilvusRetrievalAugmentedGenerationController.java b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/milvus/MilvusRetrievalAugmentedGenerationController.java index cb045fd..aa90bd7 100644 --- a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/milvus/MilvusRetrievalAugmentedGenerationController.java +++ b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/milvus/MilvusRetrievalAugmentedGenerationController.java @@ -4,8 +4,6 @@ import org.springframework.http.ResponseEntity; import org.springframework.web.bind.annotation.*; -import java.io.IOException; - @RestController @RequestMapping("/rag/milvus") public class MilvusRetrievalAugmentedGenerationController { @@ -18,12 +16,8 @@ public MilvusRetrievalAugmentedGenerationController(MilvusRetrievalAugmentedGene @PostMapping("/ingest") public ResponseEntity ingestDocuments() { - try { - milvusRetrievalAugmentedGenerationService.ingestDocuments(); - return new ResponseEntity<>(HttpStatus.OK); - } catch (IOException e) { - return new ResponseEntity<>(HttpStatus.BAD_REQUEST); - } + milvusRetrievalAugmentedGenerationService.ingestDocuments(); + return new ResponseEntity<>(HttpStatus.OK); } @GetMapping diff --git a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/milvus/MilvusRetrievalAugmentedGenerationService.java b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/milvus/MilvusRetrievalAugmentedGenerationService.java index 74d96f6..5d9bfe1 100644 --- a/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/milvus/MilvusRetrievalAugmentedGenerationService.java +++ b/src/main/java/de/cofinpro/springai/retrieval_augmented_generation/milvus/MilvusRetrievalAugmentedGenerationService.java @@ -1,166 +1,49 @@ package de.cofinpro.springai.retrieval_augmented_generation.milvus; -import com.alibaba.fastjson.JSON; -import de.cofinpro.springai.retrieval_augmented_generation.AbstractRetrievalAugmentedGenerationService; -import io.milvus.client.MilvusServiceClient; -import io.milvus.common.clientenum.ConsistencyLevelEnum; -import io.milvus.grpc.CreateCollectionRequestOrBuilder; -import io.milvus.grpc.DataType; -import io.milvus.grpc.MutationResult; -import io.milvus.param.IndexType; -import io.milvus.param.MetricType; -import io.milvus.param.R; -import io.milvus.param.RpcStatus; -import io.milvus.param.bulkinsert.BulkInsertParam; -import com.alibaba.fastjson.JSONObject; -import io.milvus.param.collection.*; -import io.milvus.param.dml.InsertParam; -import io.milvus.param.dml.SearchParam; -import io.milvus.param.index.CreateIndexParam; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.chat.prompt.SystemPromptTemplate; -import org.springframework.ai.document.Document; -import org.springframework.ai.embedding.EmbeddingClient; import org.springframework.ai.openai.OpenAiChatClient; import org.springframework.ai.reader.JsonReader; -import org.springframework.ai.vectorstore.SimpleVectorStore; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.core.io.Resource; import org.springframework.stereotype.Service; -import org.springframework.web.bind.annotation.PostMapping; -import java.io.File; -import java.io.IOException; -import java.nio.charset.Charset; -import java.nio.charset.StandardCharsets; -import java.util.Arrays; import java.util.List; import java.util.Map; -import java.util.stream.Collectors; @Service public class MilvusRetrievalAugmentedGenerationService { - private static final File VECTORSTORE_FILE = new File("data/vectorstore.json"); - private final OpenAiChatClient openAiChatClient; - private final EmbeddingClient embeddingClient; private final Resource bikesResource; private final SystemPromptTemplate systemPromptTemplate; - private final MilvusServiceClient milvusServiceClient; - private static final String COLLECTION_NAME = "TEST_COLLECTION"; + private final VectorStore vectorStore; @Autowired - public MilvusRetrievalAugmentedGenerationService(OpenAiChatClient openAiChatClient, EmbeddingClient embeddingClient, @Value("classpath:/bikes.json") Resource bikesResource, - @Value("classpath:/system-prompt-template") Resource systemPromptTemplateResource, MilvusServiceClient milvusServiceClient) { + public MilvusRetrievalAugmentedGenerationService(OpenAiChatClient openAiChatClient, + VectorStore vectorStore, + @Value("classpath:/bikes.json") Resource bikesResource, + @Value("classpath:/system-prompt-template") Resource systemPromptTemplateResource) { this.openAiChatClient = openAiChatClient; - this.embeddingClient = embeddingClient; + this.vectorStore = vectorStore; this.bikesResource = bikesResource; - this.milvusServiceClient = milvusServiceClient; this.systemPromptTemplate = new SystemPromptTemplate(systemPromptTemplateResource); } - public String retrievalAugmentedGeneration(String message) { - List> searchVectors = List.of(embeddingClient.embed(message).stream().map(Double::floatValue).toList()); - var searchParam = SearchParam.newBuilder() - .withCollectionName(COLLECTION_NAME) - .withConsistencyLevel(ConsistencyLevelEnum.STRONG) - .withOutFields(Arrays.asList("document_content")) - .withTopK(5) - .withVectors(searchVectors) - .withVectorFieldName("document_vectors") - .withParams("{\"nprobe\":10, \"offset\":0}") - .build(); - + public void ingestDocuments() { + final var jsonReader = new JsonReader(bikesResource, "name", "price", "shortDescription"); + vectorStore.add(jsonReader.get()); + } - var status = milvusServiceClient.loadCollection(LoadCollectionParam.newBuilder() - .withCollectionName(COLLECTION_NAME) - .build()); - System.out.println(status); - final var similarDocuments = milvusServiceClient.search(searchParam); - milvusServiceClient.releaseCollection(ReleaseCollectionParam.newBuilder() - .withCollectionName(COLLECTION_NAME) - .build()); - final var results = similarDocuments.toString(); - final var systemMessage = systemPromptTemplate.createMessage(Map.of("documents", results)); + public String retrievalAugmentedGeneration(String message) { + final var test = vectorStore.similaritySearch(SearchRequest.query(message)); + final var systemMessage = systemPromptTemplate.createMessage(Map.of("documents", test)); final var prompt = new Prompt(List.of(systemMessage, new UserMessage(message))); return openAiChatClient.call(prompt).getResult().toString(); } - - public void ingestDocuments() throws IOException { - FieldType fieldType1 = FieldType.newBuilder() - .withName("document_id") - .withDataType(DataType.Int64) - .withPrimaryKey(true) - .withAutoID(true) - .build(); - FieldType fieldType2 = FieldType.newBuilder() - .withName("document_content") - .withDataType(DataType.VarChar) - .withMaxLength(5000) - .build(); - FieldType fieldType3 = FieldType.newBuilder() - .withName("document_vectors") - .withDataType(DataType.FloatVector) - .withDimension(5000) - .build(); - CreateCollectionParam createCollectionReq = CreateCollectionParam.newBuilder() - .withCollectionName(COLLECTION_NAME) - .withDescription("Test document search") - .withShardsNum(2) - .addFieldType(fieldType1) - .addFieldType(fieldType2) - .addFieldType(fieldType3) - .withEnableDynamicField(true) - .build(); - - var ret = milvusServiceClient.createCollection(createCollectionReq); - if (ret.getStatus() != R.Status.Success.getCode()) { - throw new RuntimeException("Failed to create collection! Error: " + ret.getMessage()); - } - ret = milvusServiceClient.createIndex(CreateIndexParam.newBuilder() - .withCollectionName(COLLECTION_NAME) - .withFieldName("document_vectors") - .withIndexType(IndexType.FLAT) - .withMetricType(MetricType.L2) - .build()); - if (ret.getStatus() != R.Status.Success.getCode()) { - throw new RuntimeException("Failed to create index on vector field! Error: " + ret.getMessage()); - } - ret = milvusServiceClient.createIndex(CreateIndexParam.newBuilder() - .withCollectionName(COLLECTION_NAME) - .withFieldName("document_content") - .withIndexType(IndexType.TRIE) - .build()); - if (ret.getStatus() != R.Status.Success.getCode()) { - throw new RuntimeException("Failed to create index on varchar field! Error: " + ret.getMessage()); - } - - milvusServiceClient.loadCollection(LoadCollectionParam.newBuilder() - .withCollectionName(COLLECTION_NAME) - .build()); - - final var jsonReader = new JsonReader(bikesResource, "name", "price", "shortDescription"); - final List documents = jsonReader.get(); - final var documentContent = documents.stream().map(document -> new JSONObject(Map.of("document_content", document.getContent()))).toList(); - - - var insertRet = milvusServiceClient.bulkInsert(BulkInsertParam.newBuilder().withCollectionName(COLLECTION_NAME).withFiles(List.of("classpath:/bikes.json")).build()); - - /* - R insertRet = milvusServiceClient.insert(InsertParam.newBuilder() - .withCollectionName(COLLECTION_NAME) - .withRows(documentContent) - .build()); */ - if (insertRet.getStatus() != R.Status.Success.getCode()) { - throw new RuntimeException("Failed to insert! Error: " + insertRet.getMessage()); - } - - milvusServiceClient.releaseCollection(ReleaseCollectionParam.newBuilder() - .withCollectionName(COLLECTION_NAME) - .build()); - } } diff --git a/src/main/resources/application.yml b/src/main/resources/application.yml index 66181e8..c6e605e 100644 --- a/src/main/resources/application.yml +++ b/src/main/resources/application.yml @@ -3,11 +3,11 @@ spring: uris: http://localhost:9200 ai: openai: - api-key: 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 + api-key: chat: options: temperature: 0.5 - user: rubenspringai + user: vectorstore: milvus: client: From 99037aa2cff61749f6ccdb7677359fc071a66dda Mon Sep 17 00:00:00 2001 From: Ruben Leibmann Date: Sat, 3 Feb 2024 16:11:33 +0100 Subject: [PATCH 4/4] Reorder REST examples --- rest-api.http | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/rest-api.http b/rest-api.http index 64c2114..0f7510a 100644 --- a/rest-api.http +++ b/rest-api.http @@ -5,12 +5,6 @@ GET localhost:8080/chatgpt?message=hallo ### Ingest data to simple vector store POST localhost:8080/rag/simple/ingest -### Ingest data to simple vector store -POST localhost:8080/rag/milvus/ingest - -### RAG-Query using milvus vector store -GET localhost:8080/rag/milvus?message=ultimate%20mountain%20bike - ### RAG-Query using simple vector store GET localhost:8080/rag/simple?message=ultimate%20mountain%20bike @@ -22,3 +16,11 @@ POST localhost:8080/rag/elasticsearch/ingest GET localhost:8080/rag/elasticsearch?message=ultimate%20mountain%20bike +### Ingest data to milvus vector store +POST localhost:8080/rag/milvus/ingest + +### RAG-Query using milvus vector store +GET localhost:8080/rag/milvus?message=ultimate%20mountain%20bike + + +