From a8ab84d95b1d1ed457f6ca58646946894a93075f Mon Sep 17 00:00:00 2001 From: Linyu <94553312+weijinglin@users.noreply.github.com> Date: Tue, 16 Sep 2025 14:54:03 +0800 Subject: [PATCH 1/5] refactor: refactor scheduler to support dynamic workflow scheduling and pipeline pooling (#48) --- hugegraph-llm/pyproject.toml | 2 + .../src/hugegraph_llm/flows/__init__.py | 16 ++ .../hugegraph_llm/flows/build_vector_index.py | 55 ++++ .../src/hugegraph_llm/flows/common.py | 45 +++ .../src/hugegraph_llm/flows/graph_extract.py | 127 +++++++++ .../src/hugegraph_llm/flows/scheduler.py | 90 ++++++ .../models/embeddings/init_embedding.py | 36 ++- .../src/hugegraph_llm/models/llms/init_llm.py | 80 +++++- .../operators/common_op/check_schema.py | 258 ++++++++++++++++-- .../operators/document_op/chunk_split.py | 59 ++++ .../operators/hugegraph_op/schema_manager.py | 88 +++++- .../operators/index_op/build_vector_index.py | 65 ++++- .../operators/llm_op/info_extract.py | 220 +++++++++++++-- .../llm_op/property_graph_extract.py | 190 +++++++++++-- .../src/hugegraph_llm/operators/util.py | 27 ++ .../src/hugegraph_llm/state/__init__.py | 16 ++ .../src/hugegraph_llm/state/ai_state.py | 81 ++++++ .../hugegraph_llm/utils/graph_index_utils.py | 83 ++++-- .../hugegraph_llm/utils/vector_index_utils.py | 67 +++-- 19 files changed, 1472 insertions(+), 133 deletions(-) create mode 100644 hugegraph-llm/src/hugegraph_llm/flows/__init__.py create mode 100644 hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py create mode 100644 hugegraph-llm/src/hugegraph_llm/flows/common.py create mode 100644 hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py create mode 100644 hugegraph-llm/src/hugegraph_llm/flows/scheduler.py create mode 100644 hugegraph-llm/src/hugegraph_llm/operators/util.py create mode 100644 hugegraph-llm/src/hugegraph_llm/state/__init__.py create mode 100644 hugegraph-llm/src/hugegraph_llm/state/ai_state.py diff --git a/hugegraph-llm/pyproject.toml b/hugegraph-llm/pyproject.toml 
index 1bd3b748c..2b0f29ace 100644 --- a/hugegraph-llm/pyproject.toml +++ b/hugegraph-llm/pyproject.toml @@ -61,6 +61,7 @@ dependencies = [ "apscheduler", "litellm", "hugegraph-python-client", + "pycgraph", ] [project.urls] homepage = "https://hugegraph.apache.org/" @@ -88,3 +89,4 @@ allow-direct-references = true [tool.uv.sources] hugegraph-python-client = { workspace = true } +pycgraph = { git = "https://github.com/ChunelFeng/CGraph.git", subdirectory = "python", rev = "main", marker = "sys_platform == 'linux'" } diff --git a/hugegraph-llm/src/hugegraph_llm/flows/__init__.py b/hugegraph-llm/src/hugegraph_llm/flows/__init__.py new file mode 100644 index 000000000..13a83393a --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/flows/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py b/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py new file mode 100644 index 000000000..f1ee8c1c4 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
class BuildVectorIndexFlow(BaseFlow):
    """Workflow that chunks input texts and builds a vector index from the chunks.

    Pipeline topology: ChunkSplitNode -> BuildVectorIndexNode, with data shared
    through the "wkflow_input" / "wkflow_state" GParams.
    """

    def prepare(self, prepared_input: WkFlowInput, texts, language="zh", split_type="paragraph"):
        """Populate the workflow input read by the pipeline nodes.

        :param prepared_input: mutable holder shared with the pipeline nodes
        :param texts: raw text (or list of texts) to chunk and index
        :param language: chunking language ("zh" or "en"); default preserves the
            previously hard-coded behavior
        :param split_type: chunk granularity understood by ChunkSplitNode;
            default preserves the previously hard-coded "paragraph"
        """
        prepared_input.texts = texts
        prepared_input.language = language
        prepared_input.split_type = split_type

    def build_flow(self, texts, language="zh", split_type="paragraph"):
        """Create and wire a new GPipeline for vector-index building."""
        pipeline = GPipeline()
        # Prepare the workflow input before registering it with the pipeline.
        prepared_input = WkFlowInput()
        self.prepare(prepared_input, texts, language, split_type)

        pipeline.createGParam(prepared_input, "wkflow_input")
        pipeline.createGParam(WkFlowState(), "wkflow_state")

        chunk_split_node = ChunkSplitNode()
        build_vector_node = BuildVectorIndexNode()
        pipeline.registerGElement(chunk_split_node, set(), "chunk_split")
        # build_vector consumes the chunks produced by chunk_split.
        pipeline.registerGElement(build_vector_node, {chunk_split_node}, "build_vector")

        return pipeline

    def post_deal(self, pipeline=None):
        """Serialize the finished pipeline's shared state as pretty-printed JSON."""
        res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json()
        return json.dumps(res, ensure_ascii=False, indent=2)
class BaseFlow(ABC):
    """Abstract workflow contract: prepare, build_flow and post_deal."""

    @abstractmethod
    def prepare(self, prepared_input: WkFlowInput, *args, **kwargs):
        """Fill *prepared_input* with everything the pipeline nodes will read."""

    @abstractmethod
    def build_flow(self, *args, **kwargs):
        """Construct and wire a new pipeline instance for this workflow."""

    @abstractmethod
    def post_deal(self, *args, **kwargs):
        """Turn the finished pipeline's shared state into the caller-facing result."""
class GraphExtractFlow(BaseFlow):
    """Workflow that extracts a graph (vertices/edges) from raw text with an LLM.

    Pipeline topology: a schema-providing node (SchemaManagerNode or
    CheckSchemaNode) and a ChunkSplitNode both feed the extraction node.
    """

    def _import_schema(self, from_hugegraph=None, from_extraction=None, from_user_defined=None):
        """Return the schema-providing node matching the schema source."""
        if from_hugegraph:
            return SchemaManagerNode()
        if from_user_defined:
            return CheckSchemaNode()
        if from_extraction:
            raise NotImplementedError("Not implemented yet")
        raise ValueError("No input data / invalid schema type")

    @staticmethod
    def _parse_schema(schema):
        """Parse the user-supplied schema string (single source of truth).

        Returns ``(parsed, is_user_defined)``: ``parsed`` is a dict when the
        string is JSON (user-defined schema), otherwise the stripped string is
        treated as a graph name whose schema is fetched from HugeGraph.
        Raises ValueError on malformed JSON.
        """
        schema = schema.strip()
        if schema.startswith("{"):
            try:
                return json.loads(schema), True
            except json.JSONDecodeError as exc:
                log.error("Invalid JSON format in schema. Please check it again.")
                raise ValueError("Invalid JSON format in schema.") from exc
        log.info("Get schema '%s' from graphdb.", schema)
        return schema, False

    def prepare(self, prepared_input: WkFlowInput, schema, texts, example_prompt, extract_type):
        """Populate the workflow input.

        ``extract_type`` is accepted for signature compatibility with
        ``build_flow`` (the scheduler forwards the same *args to both) but is
        not stored here.
        """
        prepared_input.texts = texts
        prepared_input.language = "zh"
        prepared_input.split_type = "document"
        prepared_input.example_prompt = example_prompt
        # Keep the raw schema by default; replace with the parsed dict when the
        # user supplied JSON, or record the graph name otherwise.
        prepared_input.schema = schema
        parsed, is_user_defined = self._parse_schema(schema)
        if is_user_defined:
            prepared_input.schema = parsed
        else:
            prepared_input.graph_name = parsed

    def build_flow(self, schema, texts, example_prompt, extract_type):
        """Create and wire a new GPipeline for graph extraction."""
        pipeline = GPipeline()
        prepared_input = WkFlowInput()
        self.prepare(prepared_input, schema, texts, example_prompt, extract_type)

        pipeline.createGParam(prepared_input, "wkflow_input")
        pipeline.createGParam(WkFlowState(), "wkflow_state")

        # Reuse the shared parser instead of duplicating the JSON/graph-name
        # branching that prepare() already performs.
        parsed, is_user_defined = self._parse_schema(schema)
        if is_user_defined:
            schema_node = self._import_schema(from_user_defined=parsed)
        else:
            schema_node = self._import_schema(from_hugegraph=parsed)

        chunk_split_node = ChunkSplitNode()
        if extract_type == "triples":
            graph_extract_node = InfoExtractNode()
        elif extract_type == "property_graph":
            graph_extract_node = PropertyGraphExtractNode()
        else:
            raise ValueError(f"Unsupported extract_type: {extract_type}")

        pipeline.registerGElement(schema_node, set(), "schema_node")
        pipeline.registerGElement(chunk_split_node, set(), "chunk_split")
        # Extraction needs both the schema and the chunks.
        pipeline.registerGElement(
            graph_extract_node, {schema_node, chunk_split_node}, "graph_extract"
        )

        return pipeline

    def post_deal(self, pipeline=None):
        """Serialize extracted vertices/edges; attach a warning when empty."""
        res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json()
        vertices = res.get("vertices", [])
        edges = res.get("edges", [])
        payload = {"vertices": vertices, "edges": edges}
        if not vertices and not edges:
            log.info("Please check the schema.(The schema may not match the Doc)")
            payload["warning"] = "The schema may not match the Doc"
        return json.dumps(payload, ensure_ascii=False, indent=2)
class Scheduler:
    """Schedules named workflows over pools of reusable GPipelines.

    Each supported workflow name maps to a GPipelineManager (the pipeline pool)
    plus the BaseFlow that knows how to build/prepare/post-process pipelines.
    """

    def __init__(self, max_pipeline: int = 10):
        # pipeline_pool acts as a registry of GPipelineManager instances, one
        # per supported workflow, used for pipeline pooling/reuse.
        self.pipeline_pool: Dict[str, Any] = {
            "build_vector_index": {
                "manager": GPipelineManager(),
                "flow": BuildVectorIndexFlow(),
            },
            "graph_extract": {
                "manager": GPipelineManager(),
                "flow": GraphExtractFlow(),
            },
        }
        # NOTE(review): max_pipeline is stored but not yet enforced anywhere
        # when pipelines are added back to a manager -- TODO confirm the
        # intended pool-size limit and apply it in schedule_flow.
        self.max_pipeline = max_pipeline

    # TODO: Implement Agentic Workflow
    def agentic_flow(self):
        pass

    @staticmethod
    def _check(status, stage: str) -> None:
        """Raise (and log) a RuntimeError when a pipeline status reports an error."""
        if status.isErr():
            error_msg = f"Error in flow {stage}: {status.getInfo()}"
            log.error(error_msg)
            raise RuntimeError(error_msg)

    def schedule_flow(self, flow: str, *args, **kwargs):
        """Run workflow ``flow``, reusing a pooled pipeline when one is free.

        :param flow: workflow name registered in ``pipeline_pool``
        :raises ValueError: for an unknown workflow name
        :raises RuntimeError: when pipeline init or execution fails
        :return: the flow's post-processed (JSON string) result
        """
        if flow not in self.pipeline_pool:
            raise ValueError(f"Unsupported workflow {flow}")
        entry = self.pipeline_pool[flow]
        manager = entry["manager"]
        # Do not shadow the `flow` parameter with the flow object.
        flow_impl: BaseFlow = entry["flow"]

        pipeline = manager.fetch()
        if pipeline is None:
            # No pooled pipeline available: call the corresponding flow to
            # build a new one, then pool it after a successful run.
            pipeline = flow_impl.build_flow(*args, **kwargs)
            self._check(pipeline.init(), "init")
            self._check(pipeline.run(), "execution")
            res = flow_impl.post_deal(pipeline)
            manager.add(pipeline)
            return res

        # Reuse a pooled pipeline: refresh its input, run, then hand it back.
        prepared_input = pipeline.getGParamWithNoEmpty("wkflow_input")
        flow_impl.prepare(prepared_input, *args, **kwargs)
        self._check(pipeline.run(), "execution")
        res = flow_impl.post_deal(pipeline)
        manager.release(pipeline)
        return res


class SchedulerSingleton:
    """Process-wide lazily created Scheduler (double-checked locking)."""

    _instance = None
    _instance_lock = threading.Lock()

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            with cls._instance_lock:
                # Re-check under the lock so only one thread constructs it.
                if cls._instance is None:
                    cls._instance = Scheduler()
        return cls._instance
def get_embedding(llm_settings: "LLMConfig"):
    """Build an embedding client from ``llm_settings.embedding_type``.

    Supported types: "openai", "ollama/local", "litellm".

    :raises ValueError: for an unsupported embedding type (ValueError is an
        Exception subclass, so existing broad handlers keep working).
    """
    if llm_settings.embedding_type == "openai":
        return OpenAIEmbedding(
            model_name=llm_settings.openai_embedding_model,
            api_key=llm_settings.openai_embedding_api_key,
            api_base=llm_settings.openai_embedding_api_base,
        )
    if llm_settings.embedding_type == "ollama/local":
        return OllamaEmbedding(
            model_name=llm_settings.ollama_embedding_model,
            host=llm_settings.ollama_embedding_host,
            port=llm_settings.ollama_embedding_port,
        )
    if llm_settings.embedding_type == "litellm":
        return LiteLLMEmbedding(
            model_name=llm_settings.litellm_embedding_model,
            api_key=llm_settings.litellm_embedding_api_key,
            api_base=llm_settings.litellm_embedding_api_base,
        )
    raise ValueError(f"embedding type is not supported: {llm_settings.embedding_type!r}")
def get_chat_llm(llm_settings: "LLMConfig"):
    """Build a chat LLM client from ``llm_settings.chat_llm_type``.

    Supported types: "openai", "ollama/local", "litellm".

    :raises ValueError: for an unsupported chat llm type (ValueError is an
        Exception subclass, so existing broad handlers keep working).
    """
    if llm_settings.chat_llm_type == "openai":
        return OpenAIClient(
            api_key=llm_settings.openai_chat_api_key,
            api_base=llm_settings.openai_chat_api_base,
            model_name=llm_settings.openai_chat_language_model,
            max_tokens=llm_settings.openai_chat_tokens,
        )
    if llm_settings.chat_llm_type == "ollama/local":
        return OllamaClient(
            model=llm_settings.ollama_chat_language_model,
            host=llm_settings.ollama_chat_host,
            port=llm_settings.ollama_chat_port,
        )
    if llm_settings.chat_llm_type == "litellm":
        return LiteLLMClient(
            api_key=llm_settings.litellm_chat_api_key,
            api_base=llm_settings.litellm_chat_api_base,
            model_name=llm_settings.litellm_chat_language_model,
            max_tokens=llm_settings.litellm_chat_tokens,
        )
    raise ValueError(f"chat llm type is not supported: {llm_settings.chat_llm_type!r}")


def get_extract_llm(llm_settings: "LLMConfig"):
    """Build an extraction LLM client from ``llm_settings.extract_llm_type``.

    Supported types: "openai", "ollama/local", "litellm".

    :raises ValueError: for an unsupported extract llm type.
    """
    if llm_settings.extract_llm_type == "openai":
        return OpenAIClient(
            api_key=llm_settings.openai_extract_api_key,
            api_base=llm_settings.openai_extract_api_base,
            model_name=llm_settings.openai_extract_language_model,
            max_tokens=llm_settings.openai_extract_tokens,
        )
    if llm_settings.extract_llm_type == "ollama/local":
        return OllamaClient(
            model=llm_settings.ollama_extract_language_model,
            host=llm_settings.ollama_extract_host,
            port=llm_settings.ollama_extract_port,
        )
    if llm_settings.extract_llm_type == "litellm":
        return LiteLLMClient(
            api_key=llm_settings.litellm_extract_api_key,
            api_base=llm_settings.litellm_extract_api_base,
            model_name=llm_settings.litellm_extract_language_model,
            max_tokens=llm_settings.litellm_extract_tokens,
        )
    raise ValueError(f"extract llm type is not supported: {llm_settings.extract_llm_type!r}")
def get_text2gql_llm(llm_settings: "LLMConfig"):
    """Build a text2gql LLM client from ``llm_settings.text2gql_llm_type``.

    Supported types: "openai", "ollama/local", "litellm".

    :raises ValueError: for an unsupported text2gql llm type (ValueError is an
        Exception subclass, so existing broad handlers keep working).
    """
    if llm_settings.text2gql_llm_type == "openai":
        return OpenAIClient(
            api_key=llm_settings.openai_text2gql_api_key,
            api_base=llm_settings.openai_text2gql_api_base,
            model_name=llm_settings.openai_text2gql_language_model,
            max_tokens=llm_settings.openai_text2gql_tokens,
        )
    if llm_settings.text2gql_llm_type == "ollama/local":
        return OllamaClient(
            model=llm_settings.ollama_text2gql_language_model,
            host=llm_settings.ollama_text2gql_host,
            port=llm_settings.ollama_text2gql_port,
        )
    if llm_settings.text2gql_llm_type == "litellm":
        return LiteLLMClient(
            api_key=llm_settings.litellm_text2gql_api_key,
            api_base=llm_settings.litellm_text2gql_api_base,
            model_name=llm_settings.litellm_text2gql_language_model,
            max_tokens=llm_settings.litellm_text2gql_tokens,
        )
    raise ValueError(f"text2gql llm type is not supported: {llm_settings.text2gql_llm_type!r}")
class CheckSchemaNode(GNode):
    """Pipeline node that validates a user-defined schema and fills in defaults.

    Reads the raw schema dict from the "wkflow_input" GParam and publishes the
    normalized schema (with completed propertykeys) into "wkflow_state".
    """

    context: WkFlowState = None
    wk_input: WkFlowInput = None

    def init(self):
        # Bind self.context / self.wk_input to the pipeline's shared GParams.
        return init_context(self)

    def node_init(self):
        """Pull the schema from the workflow input; error status when absent."""
        if self.wk_input.schema is None:
            return CStatus(-1, "Error occurs when prepare for workflow input")
        self.data = self.wk_input.schema
        return CStatus()

    def run(self) -> CStatus:
        # init workflow input
        sts = self.node_init()
        if sts.isErr():
            return sts
        self.context.lock()
        try:
            # 1. Validate the schema structure
            schema = self.data or self.context.schema
            self._validate_schema(schema)
            # 2. Process property labels and also create a set for it
            property_labels, property_label_set = self._process_property_labels(schema)
            # 3. Process properties in given vertex/edge labels
            self._process_vertex_labels(schema, property_labels, property_label_set)
            self._process_edge_labels(schema, property_labels, property_label_set)
            # 4. Update schema with processed pks
            schema["propertykeys"] = property_labels
            self.context.schema = schema
        finally:
            # Always release the shared-state lock, even when validation raises,
            # so an invalid schema cannot leave the context locked for later
            # nodes (fixes a lock leak in the original code).
            self.context.unlock()
        return CStatus()

    def _validate_schema(self, schema: Dict[str, Any]) -> None:
        """Check top-level structure: a dict with 'vertexlabels'/'edgelabels' lists."""
        check_type(schema, dict, "Input data is not a dictionary.")
        if "vertexlabels" not in schema or "edgelabels" not in schema:
            log_and_raise("Input data does not contain 'vertexlabels' or 'edgelabels'.")
        check_type(
            schema["vertexlabels"], list, "'vertexlabels' in input data is not a list."
        )
        check_type(
            schema["edgelabels"], list, "'edgelabels' in input data is not a list."
        )

    def _process_property_labels(self, schema: Dict[str, Any]) -> (list, set):
        """Return existing propertykeys plus a name set for fast membership checks."""
        property_labels = schema.get("propertykeys", [])
        check_type(
            property_labels,
            list,
            "'propertykeys' in input data is not of correct type.",
        )
        property_label_set = {label["name"] for label in property_labels}
        return property_labels, property_label_set

    def _process_vertex_labels(
        self, schema: Dict[str, Any], property_labels: list, property_label_set: set
    ) -> None:
        """Normalize each vertex label: primary/nullable keys and missing properties."""
        for vertex_label in schema["vertexlabels"]:
            self._validate_vertex_label(vertex_label)
            properties = vertex_label["properties"]
            # Default primary key: the first declared property.
            primary_keys = self._process_keys(vertex_label, "primary_keys", properties[:1])
            if len(primary_keys) == 0:
                log_and_raise(f"'primary_keys' of {vertex_label['name']} is empty.")
            vertex_label["primary_keys"] = primary_keys
            # Default nullable keys: every property after the first.
            nullable_keys = self._process_keys(vertex_label, "nullable_keys", properties[1:])
            vertex_label["nullable_keys"] = nullable_keys
            self._add_missing_properties(properties, property_labels, property_label_set)

    def _process_edge_labels(
        self, schema: Dict[str, Any], property_labels: list, property_label_set: set
    ) -> None:
        """Validate edge labels and register any properties they mention."""
        for edge_label in schema["edgelabels"]:
            self._validate_edge_label(edge_label)
            properties = edge_label.get("properties", [])
            self._add_missing_properties(properties, property_labels, property_label_set)

    def _validate_vertex_label(self, vertex_label: Dict[str, Any]) -> None:
        """A vertex label must be a dict with 'name' and a non-empty 'properties' list."""
        check_type(vertex_label, dict, "VertexLabel in input data is not a dictionary.")
        if "name" not in vertex_label:
            log_and_raise("VertexLabel in input data does not contain 'name'.")
        check_type(
            vertex_label["name"], str, "'name' in vertex_label is not of correct type."
        )
        if "properties" not in vertex_label:
            log_and_raise("VertexLabel in input data does not contain 'properties'.")
        check_type(
            vertex_label["properties"],
            list,
            "'properties' in vertex_label is not of correct type.",
        )
        if len(vertex_label["properties"]) == 0:
            log_and_raise("'properties' in vertex_label is empty.")

    def _validate_edge_label(self, edge_label: Dict[str, Any]) -> None:
        """An edge label must be a dict with 'name', 'source_label', 'target_label' strings."""
        check_type(edge_label, dict, "EdgeLabel in input data is not a dictionary.")
        if (
            "name" not in edge_label
            or "source_label" not in edge_label
            or "target_label" not in edge_label
        ):
            log_and_raise(
                "EdgeLabel in input data does not contain 'name', 'source_label', 'target_label'."
            )
        check_type(
            edge_label["name"], str, "'name' in edge_label is not of correct type."
        )
        check_type(
            edge_label["source_label"],
            str,
            "'source_label' in edge_label is not of correct type.",
        )
        check_type(
            edge_label["target_label"],
            str,
            "'target_label' in edge_label is not of correct type.",
        )

    def _process_keys(
        self, label: Dict[str, Any], key_type: str, default_keys: list
    ) -> list:
        """Return label[key_type] (or the default) filtered to declared properties."""
        keys = label.get(key_type, default_keys)
        check_type(
            keys, list, f"'{key_type}' in {label['name']} is not of correct type."
        )
        return [key for key in keys if key in label["properties"]]

    def _add_missing_properties(
        self, properties: list, property_labels: list, property_label_set: set
    ) -> None:
        """Append a default propertykey entry for any property not declared yet."""
        for prop in properties:
            if prop not in property_label_set:
                property_labels.append(
                    {
                        "name": prop,
                        "data_type": PropertyDataType.DEFAULT.value,
                        "cardinality": PropertyCardinality.DEFAULT.value,
                    }
                )
                property_label_set.add(prop)

    def get_result(self):
        """Snapshot the shared state as JSON, holding the state lock."""
        self.context.lock()
        try:
            return self.context.to_json()
        finally:
            # Release the lock even if serialization fails.
            self.context.unlock()
PyCGraph import GNode, CStatus # Constants LANGUAGE_ZH = "zh" @@ -27,6 +29,63 @@ SPLIT_TYPE_PARAGRAPH = "paragraph" SPLIT_TYPE_SENTENCE = "sentence" + +class ChunkSplitNode(GNode): + def init(self): + return init_context(self) + + def node_init(self): + if ( + self.wk_input.texts is None + or self.wk_input.language is None + or self.wk_input.split_type is None + ): + return CStatus(-1, "Error occurs when prepare for workflow input") + texts = self.wk_input.texts + language = self.wk_input.language + split_type = self.wk_input.split_type + if isinstance(texts, str): + texts = [texts] + self.texts = texts + self.separators = self._get_separators(language) + self.text_splitter = self._get_text_splitter(split_type) + return CStatus() + + def _get_separators(self, language: str) -> List[str]: + if language == LANGUAGE_ZH: + return ["\n\n", "\n", "。", ",", ""] + if language == LANGUAGE_EN: + return ["\n\n", "\n", ".", ",", " ", ""] + raise ValueError("language must be zh or en") + + def _get_text_splitter(self, split_type: str): + if split_type == SPLIT_TYPE_DOCUMENT: + return lambda text: [text] + if split_type == SPLIT_TYPE_PARAGRAPH: + return RecursiveCharacterTextSplitter( + chunk_size=500, chunk_overlap=30, separators=self.separators + ).split_text + if split_type == SPLIT_TYPE_SENTENCE: + return RecursiveCharacterTextSplitter( + chunk_size=50, chunk_overlap=0, separators=self.separators + ).split_text + raise ValueError("Type must be document, paragraph or sentence") + + def run(self): + sts = self.node_init() + if sts.isErr(): + return sts + all_chunks = [] + for text in self.texts: + chunks = self.text_splitter(text) + all_chunks.extend(chunks) + + self.context.lock() + self.context.chunks = all_chunks + self.context.unlock() + return CStatus() + + class ChunkSplit: def __init__( self, diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py index 
2f50bb818..670c18b4a 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py @@ -17,8 +17,12 @@ from typing import Dict, Any, Optional from hugegraph_llm.config import huge_settings +from hugegraph_llm.operators.util import init_context +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState from pyhugegraph.client import PyHugeClient +from PyCGraph import GNode, CStatus + class SchemaManager: def __init__(self, graph_name: str): @@ -39,15 +43,22 @@ def simple_schema(self, schema: Dict[str, Any]) -> Dict[str, Any]: if "vertexlabels" in schema: mini_schema["vertexlabels"] = [] for vertex in schema["vertexlabels"]: - new_vertex = {key: vertex[key] for key in ["id", "name", "properties"] if key in vertex} + new_vertex = { + key: vertex[key] + for key in ["id", "name", "properties"] + if key in vertex + } mini_schema["vertexlabels"].append(new_vertex) # Add necessary edgelabels items (4) if "edgelabels" in schema: mini_schema["edgelabels"] = [] for edge in schema["edgelabels"]: - new_edge = {key: edge[key] for key in - ["name", "source_label", "target_label", "properties"] if key in edge} + new_edge = { + key: edge[key] + for key in ["name", "source_label", "target_label", "properties"] + if key in edge + } mini_schema["edgelabels"].append(new_edge) return mini_schema @@ -63,3 +74,74 @@ def run(self, context: Optional[Dict[str, Any]]) -> Dict[str, Any]: # TODO: enhance the logic here context["simple_schema"] = self.simple_schema(schema) return context + + +class SchemaManagerNode(GNode): + context: WkFlowState = None + wk_input: WkFlowInput = None + + def init(self): + return init_context(self) + + def node_init(self): + if self.wk_input.graph_name is None: + return CStatus(-1, "Error occurs when prepare for workflow input") + graph_name = self.wk_input.graph_name + self.graph_name = graph_name + self.client = PyHugeClient( + 
url=huge_settings.graph_url, + graph=self.graph_name, + user=huge_settings.graph_user, + pwd=huge_settings.graph_pwd, + graphspace=huge_settings.graph_space, + ) + self.schema = self.client.schema() + return CStatus() + + def simple_schema(self, schema: Dict[str, Any]) -> Dict[str, Any]: + mini_schema = {} + + # Add necessary vertexlabels items (3) + if "vertexlabels" in schema: + mini_schema["vertexlabels"] = [] + for vertex in schema["vertexlabels"]: + new_vertex = { + key: vertex[key] + for key in ["id", "name", "properties"] + if key in vertex + } + mini_schema["vertexlabels"].append(new_vertex) + + # Add necessary edgelabels items (4) + if "edgelabels" in schema: + mini_schema["edgelabels"] = [] + for edge in schema["edgelabels"]: + new_edge = { + key: edge[key] + for key in ["name", "source_label", "target_label", "properties"] + if key in edge + } + mini_schema["edgelabels"].append(new_edge) + + return mini_schema + + def run(self) -> CStatus: + sts = self.node_init() + if sts.isErr(): + return sts + schema = self.schema.getSchema() + if not schema["vertexlabels"] and not schema["edgelabels"]: + raise Exception(f"Can not get {self.graph_name}'s schema from HugeGraph!") + + self.context.lock() + self.context.schema = schema + # TODO: enhance the logic here + self.context.simple_schema = self.simple_schema(schema) + self.context.unlock() + return CStatus() + + def get_result(self): + self.context.lock() + res = self.context.to_json() + self.context.unlock() + return res diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py index ffb35564b..ee89d330f 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py @@ -23,20 +23,75 @@ from hugegraph_llm.config import huge_settings, resource_path, llm_settings from hugegraph_llm.indices.vector_index import 
VectorIndex from hugegraph_llm.models.embeddings.base import BaseEmbedding -from hugegraph_llm.utils.embedding_utils import get_embeddings_parallel, get_filename_prefix, get_index_folder_name +from hugegraph_llm.utils.embedding_utils import ( + get_embeddings_parallel, + get_filename_prefix, + get_index_folder_name, +) from hugegraph_llm.utils.log import log +from hugegraph_llm.operators.util import init_context +from hugegraph_llm.models.embeddings.init_embedding import get_embedding +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from PyCGraph import GNode, CStatus + + +class BuildVectorIndexNode(GNode): + context: WkFlowState = None + wk_input: WkFlowInput = None + + def init(self): + return init_context(self) + + def node_init(self): + self.embedding = get_embedding(llm_settings) + self.folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) + self.index_dir = str(os.path.join(resource_path, self.folder_name, "chunks")) + self.filename_prefix = get_filename_prefix( + llm_settings.embedding_type, getattr(self.embedding, "model_name", None) + ) + self.vector_index = VectorIndex.from_index_file( + self.index_dir, self.filename_prefix + ) + return CStatus() + + def run(self): + # init workflow input + sts = self.node_init() + if sts.isErr(): + return sts + self.context.lock() + try: + if self.context.chunks is None: + raise ValueError("chunks not found in context.") + chunks = self.context.chunks + finally: + self.context.unlock() + chunks_embedding = [] + log.debug("Building vector index for %s chunks...", len(chunks)) + # TODO: use async_get_texts_embedding instead of single sync method + chunks_embedding = asyncio.run(get_embeddings_parallel(self.embedding, chunks)) + if len(chunks_embedding) > 0: + self.vector_index.add(chunks_embedding, chunks) + self.vector_index.to_index_file(self.index_dir, self.filename_prefix) + return CStatus() + class BuildVectorIndex: def __init__(self, embedding: 
BaseEmbedding): self.embedding = embedding - self.folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) + self.folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) self.index_dir = str(os.path.join(resource_path, self.folder_name, "chunks")) self.filename_prefix = get_filename_prefix( - llm_settings.embedding_type, - getattr(self.embedding, "model_name", None) + llm_settings.embedding_type, getattr(self.embedding, "model_name", None) + ) + self.vector_index = VectorIndex.from_index_file( + self.index_dir, self.filename_prefix ) - self.vector_index = VectorIndex.from_index_file(self.index_dir, self.filename_prefix) def run(self, context: Dict[str, Any]) -> Dict[str, Any]: if "chunks" not in context: diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py index 42bb6b108..15a8fdda7 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py @@ -18,20 +18,26 @@ import re from typing import List, Any, Dict, Optional +from hugegraph_llm.config import llm_settings from hugegraph_llm.document.chunk_split import ChunkSplitter from hugegraph_llm.models.llms.base import BaseLLM from hugegraph_llm.utils.log import log +from hugegraph_llm.operators.util import init_context +from hugegraph_llm.models.llms.init_llm import get_chat_llm +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from PyCGraph import GNode, CStatus + SCHEMA_EXAMPLE_PROMPT = """## Main Task Extract Triples from the given text and graph schema ## Basic Rules 1. The output format must be: (X,Y,Z) - LABEL -In this format, Y must be a value from "properties" or "edge_label", +In this format, Y must be a value from "properties" or "edge_label", and LABEL must be X's vertex_label or Y's edge_label. 2. 
Don't extract attribute/property fields that do not exist in the given schema 3. Ensure the extract property is in the same type as the schema (like 'age' should be a number) -4. Translate the given schema filed into Chinese if the given text is Chinese but the schema is in English (Optional) +4. Translate the given schema filed into Chinese if the given text is Chinese but the schema is in English (Optional) ## Example (Note: Update the example to correspond to the given text and schema) ### Input example: @@ -75,8 +81,10 @@ def generate_extract_triple_prompt(text, schema=None) -> str: if schema: return schema_real_prompt - log.warning("Recommend to provide a graph schema to improve the extraction accuracy. " - "Now using the default schema.") + log.warning( + "Recommend to provide a graph schema to improve the extraction accuracy. " + "Now using the default schema." + ) return text_based_prompt @@ -105,11 +113,17 @@ def extract_triples_by_regex_with_schema(schema, text, graph): # TODO: use a more efficient way to compare the extract & input property p_lower = p.lower() for vertex in schema["vertices"]: - if vertex["vertex_label"] == label and any(pp.lower() == p_lower - for pp in vertex["properties"]): + if vertex["vertex_label"] == label and any( + pp.lower() == p_lower for pp in vertex["properties"] + ): id = f"{label}-{s}" if id not in vertices_dict: - vertices_dict[id] = {"id": id, "name": s, "label": label, "properties": {p: o}} + vertices_dict[id] = { + "id": id, + "name": s, + "label": label, + "properties": {p: o}, + } else: vertices_dict[id]["properties"].update({p: o}) break @@ -118,25 +132,35 @@ def extract_triples_by_regex_with_schema(schema, text, graph): source_label = edge["source_vertex_label"] source_id = f"{source_label}-{s}" if source_id not in vertices_dict: - vertices_dict[source_id] = {"id": source_id, "name": s, "label": source_label, - "properties": {}} + vertices_dict[source_id] = { + "id": source_id, + "name": s, + "label": source_label, 
+ "properties": {}, + } target_label = edge["target_vertex_label"] target_id = f"{target_label}-{o}" if target_id not in vertices_dict: - vertices_dict[target_id] = {"id": target_id, "name": o, "label": target_label, - "properties": {}} - graph["edges"].append({"start": source_id, "end": target_id, "type": label, - "properties": {}}) + vertices_dict[target_id] = { + "id": target_id, + "name": o, + "label": target_label, + "properties": {}, + } + graph["edges"].append( + { + "start": source_id, + "end": target_id, + "type": label, + "properties": {}, + } + ) break - graph["vertices"] = vertices_dict.values() + graph["vertices"] = list(vertices_dict.values()) class InfoExtract: - def __init__( - self, - llm: BaseLLM, - example_prompt: Optional[str] = None - ) -> None: + def __init__(self, llm: BaseLLM, example_prompt: Optional[str] = None) -> None: self.llm = llm self.example_prompt = example_prompt @@ -152,7 +176,12 @@ def run(self, context: Dict[str, Any]) -> Dict[str, List[Any]]: for sentence in chunks: proceeded_chunk = self.extract_triples_by_llm(schema, sentence) - log.debug("[Legacy] %s input: %s \n output:%s", self.__class__.__name__, sentence, proceeded_chunk) + log.debug( + "[Legacy] %s input: %s \n output:%s", + self.__class__.__name__, + sentence, + proceeded_chunk, + ) if schema: extract_triples_by_regex_with_schema(schema, proceeded_chunk, context) else: @@ -175,7 +204,152 @@ def valid(self, element_id: str, max_length: int = 256) -> bool: return True def _filter_long_id(self, graph) -> Dict[str, List[Any]]: - graph["vertices"] = [vertex for vertex in graph["vertices"] if self.valid(vertex["id"])] - graph["edges"] = [edge for edge in graph["edges"] - if self.valid(edge["start"]) and self.valid(edge["end"])] + graph["vertices"] = [ + vertex for vertex in graph["vertices"] if self.valid(vertex["id"]) + ] + graph["edges"] = [ + edge + for edge in graph["edges"] + if self.valid(edge["start"]) and self.valid(edge["end"]) + ] return graph + + +class 
InfoExtractNode(GNode): + context: WkFlowState = None + wk_input: WkFlowInput = None + + def init(self): + return init_context(self) + + def node_init(self): + self.llm = get_chat_llm(llm_settings) + if self.wk_input.example_prompt is None: + return CStatus(-1, "Error occurs when prepare for workflow input") + self.example_prompt = self.wk_input.example_prompt + return CStatus() + + def extract_triples_by_regex_with_schema(self, schema, text): + text = text.replace("\\n", " ").replace("\\", " ").replace("\n", " ") + pattern = r"\((.*?), (.*?), (.*?)\) - ([^ ]*)" + matches = re.findall(pattern, text) + + vertices_dict = {v["id"]: v for v in self.context.vertices} + for match in matches: + s, p, o, label = [item.strip() for item in match] + if None in [label, s, p, o]: + continue + # TODO: use a more efficient way to compare the extract & input property + p_lower = p.lower() + for vertex in schema["vertices"]: + if vertex["vertex_label"] == label and any( + pp.lower() == p_lower for pp in vertex["properties"] + ): + id = f"{label}-{s}" + if id not in vertices_dict: + vertices_dict[id] = { + "id": id, + "name": s, + "label": label, + "properties": {p: o}, + } + else: + vertices_dict[id]["properties"].update({p: o}) + break + for edge in schema["edges"]: + if edge["edge_label"] == label: + source_label = edge["source_vertex_label"] + source_id = f"{source_label}-{s}" + if source_id not in vertices_dict: + vertices_dict[source_id] = { + "id": source_id, + "name": s, + "label": source_label, + "properties": {}, + } + target_label = edge["target_vertex_label"] + target_id = f"{target_label}-{o}" + if target_id not in vertices_dict: + vertices_dict[target_id] = { + "id": target_id, + "name": o, + "label": target_label, + "properties": {}, + } + self.context.edges.append( + { + "start": source_id, + "end": target_id, + "type": label, + "properties": {}, + } + ) + break + self.context.vertices = list(vertices_dict.values()) + + def extract_triples_by_regex(self, text): + 
text = text.replace("\\n", " ").replace("\\", " ").replace("\n", " ") + pattern = r"\((.*?), (.*?), (.*?)\)" + self.context.triples += re.findall(pattern, text) + + def run(self) -> CStatus: + sts = self.node_init() + if sts.isErr(): + return sts + self.context.lock() + if self.context.chunks is None: + self.context.unlock() + raise ValueError("parameter required by extract node not found in context.") + schema = self.context.schema + chunks = self.context.chunks + + if schema: + self.context.vertices = [] + self.context.edges = [] + else: + self.context.triples = [] + + self.context.unlock() + + for sentence in chunks: + proceeded_chunk = self.extract_triples_by_llm(schema, sentence) + log.debug( + "[Legacy] %s input: %s \n output:%s", + self.__class__.__name__, + sentence, + proceeded_chunk, + ) + if schema: + self.extract_triples_by_regex_with_schema(schema, proceeded_chunk) + else: + self.extract_triples_by_regex(proceeded_chunk) + + if self.context.call_count: + self.context.call_count += len(chunks) + else: + self.context.call_count = len(chunks) + self._filter_long_id() + return CStatus() + + def extract_triples_by_llm(self, schema, chunk) -> str: + prompt = generate_extract_triple_prompt(chunk, schema) + if self.example_prompt is not None: + prompt = self.example_prompt + prompt + return self.llm.generate(prompt=prompt) + + # TODO: make 'max_length' be a configurable param in settings.py/settings.cfg + def valid(self, element_id: str, max_length: int = 256) -> bool: + if len(element_id.encode("utf-8")) >= max_length: + log.warning("Filter out GraphElementID too long: %s", element_id) + return False + return True + + def _filter_long_id(self): + self.context.vertices = [ + vertex for vertex in self.context.vertices if self.valid(vertex["id"]) + ] + self.context.edges = [ + edge + for edge in self.context.edges + if self.valid(edge["start"]) and self.valid(edge["end"]) + ] diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py 
b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py index faff1c6b2..6e492b8f5 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py @@ -21,16 +21,19 @@ import re from typing import List, Any, Dict -from hugegraph_llm.config import prompt +from hugegraph_llm.config import llm_settings, prompt from hugegraph_llm.document.chunk_split import ChunkSplitter from hugegraph_llm.models.llms.base import BaseLLM from hugegraph_llm.utils.log import log -""" -TODO: It is not clear whether there is any other dependence on the SCHEMA_EXAMPLE_PROMPT variable. -Because the SCHEMA_EXAMPLE_PROMPT variable will no longer change based on -prompt.extract_graph_prompt changes after the system loads, this does not seem to meet expectations. -""" +from hugegraph_llm.operators.util import init_context +from hugegraph_llm.models.llms.init_llm import get_chat_llm +from hugegraph_llm.state.ai_state import WkFlowState, WkFlowInput +from PyCGraph import GNode, CStatus + +# TODO: It is not clear whether there is any other dependence on the SCHEMA_EXAMPLE_PROMPT variable. +# Because the SCHEMA_EXAMPLE_PROMPT variable will no longer change based on +# prompt.extract_graph_prompt changes after the system loads, this does not seem to meet expectations. 
SCHEMA_EXAMPLE_PROMPT = prompt.extract_graph_prompt @@ -60,20 +63,18 @@ def filter_item(schema, items) -> List[Dict[str, Any]]: properties_map["vertex"][vertex["name"]] = { "primary_keys": vertex["primary_keys"], "nullable_keys": vertex["nullable_keys"], - "properties": vertex["properties"] + "properties": vertex["properties"], } for edge in schema["edgelabels"]: - properties_map["edge"][edge["name"]] = { - "properties": edge["properties"] - } + properties_map["edge"][edge["name"]] = {"properties": edge["properties"]} log.info("properties_map: %s", properties_map) for item in items: item_type = item["type"] if item_type == "vertex": label = item["label"] - non_nullable_keys = ( - set(properties_map[item_type][label]["properties"]) - .difference(set(properties_map[item_type][label]["nullable_keys"]))) + non_nullable_keys = set( + properties_map[item_type][label]["properties"] + ).difference(set(properties_map[item_type][label]["nullable_keys"])) for key in non_nullable_keys: if key not in item["properties"]: item["properties"][key] = "NULL" @@ -87,9 +88,7 @@ def filter_item(schema, items) -> List[Dict[str, Any]]: class PropertyGraphExtract: def __init__( - self, - llm: BaseLLM, - example_prompt: str = prompt.extract_graph_prompt + self, llm: BaseLLM, example_prompt: str = prompt.extract_graph_prompt ) -> None: self.llm = llm self.example_prompt = example_prompt @@ -105,7 +104,12 @@ def run(self, context: Dict[str, Any]) -> Dict[str, List[Any]]: items = [] for chunk in chunks: proceeded_chunk = self.extract_property_graph_by_llm(schema, chunk) - log.debug("[LLM] %s input: %s \n output:%s", self.__class__.__name__, chunk, proceeded_chunk) + log.debug( + "[LLM] %s input: %s \n output:%s", + self.__class__.__name__, + chunk, + proceeded_chunk, + ) items.extend(self._extract_and_filter_label(schema, proceeded_chunk)) items = filter_item(schema, items) for item in items: @@ -125,10 +129,132 @@ def extract_property_graph_by_llm(self, schema, chunk): def 
_extract_and_filter_label(self, schema, text) -> List[Dict[str, Any]]: # Use regex to extract a JSON object with curly braces - json_match = re.search(r'({.*})', text, re.DOTALL) + json_match = re.search(r"({.*})", text, re.DOTALL) + if not json_match: + log.critical( + "Invalid property graph! No JSON object found, " + "please check the output format example in prompt." + ) + return [] + json_str = json_match.group(1).strip() + + items = [] + try: + property_graph = json.loads(json_str) + # Expect property_graph to be a dict with keys "vertices" and "edges" + if not ( + isinstance(property_graph, dict) + and "vertices" in property_graph + and "edges" in property_graph + ): + log.critical( + "Invalid property graph format; expecting 'vertices' and 'edges'." + ) + return items + + # Create sets for valid vertex and edge labels based on the schema + vertex_label_set = {vertex["name"] for vertex in schema["vertexlabels"]} + edge_label_set = {edge["name"] for edge in schema["edgelabels"]} + + def process_items(item_list, valid_labels, item_type): + for item in item_list: + if not isinstance(item, dict): + log.warning( + "Invalid property graph item type '%s'.", type(item) + ) + continue + if not self.NECESSARY_ITEM_KEYS.issubset(item.keys()): + log.warning("Invalid item keys '%s'.", item.keys()) + continue + if item["label"] not in valid_labels: + log.warning( + "Invalid %s label '%s' has been ignored.", + item_type, + item["label"], + ) + continue + items.append(item) + + process_items(property_graph["vertices"], vertex_label_set, "vertex") + process_items(property_graph["edges"], edge_label_set, "edge") + except json.JSONDecodeError: + log.critical( + "Invalid property graph JSON! 
Please check the extracted JSON data carefully" + ) + return items + + +class PropertyGraphExtractNode(GNode): + context: WkFlowState = None + wk_input: WkFlowInput = None + + def init(self): + self.NECESSARY_ITEM_KEYS = {"label", "type", "properties"} # pylint: disable=invalid-name + return init_context(self) + + def node_init(self): + self.llm = get_chat_llm(llm_settings) + if self.wk_input.example_prompt is None: + return CStatus(-1, "Error occurs when prepare for workflow input") + self.example_prompt = self.wk_input.example_prompt + return CStatus() + + def run(self) -> CStatus: + sts = self.node_init() + if sts.isErr(): + return sts + self.context.lock() + try: + if self.context.schema is None or self.context.chunks is None: + raise ValueError( + "parameter required by extract node not found in context." + ) + schema = self.context.schema + chunks = self.context.chunks + if self.context.vertices is None: + self.context.vertices = [] + if self.context.edges is None: + self.context.edges = [] + finally: + self.context.unlock() + + items = [] + for chunk in chunks: + proceeded_chunk = self.extract_property_graph_by_llm(schema, chunk) + log.debug( + "[LLM] %s input: %s \n output:%s", + self.__class__.__name__, + chunk, + proceeded_chunk, + ) + items.extend(self._extract_and_filter_label(schema, proceeded_chunk)) + items = filter_item(schema, items) + self.context.lock() + try: + for item in items: + if item["type"] == "vertex": + self.context.vertices.append(item) + elif item["type"] == "edge": + self.context.edges.append(item) + finally: + self.context.unlock() + self.context.call_count = (self.context.call_count or 0) + len(chunks) + return CStatus() + + def extract_property_graph_by_llm(self, schema, chunk): + prompt = generate_extract_property_graph_prompt(chunk, schema) + if self.example_prompt is not None: + prompt = self.example_prompt + prompt + return self.llm.generate(prompt=prompt) + + def _extract_and_filter_label(self, schema, text) -> List[Dict[str, 
Any]]: + # Use regex to extract a JSON object with curly braces + json_match = re.search(r"({.*})", text, re.DOTALL) if not json_match: - log.critical("Invalid property graph! No JSON object found, " - "please check the output format example in prompt.") + log.critical( + "Invalid property graph! No JSON object found, " + "please check the output format example in prompt." + ) return [] json_str = json_match.group(1).strip() @@ -136,8 +262,14 @@ def _extract_and_filter_label(self, schema, text) -> List[Dict[str, Any]]: try: property_graph = json.loads(json_str) # Expect property_graph to be a dict with keys "vertices" and "edges" - if not (isinstance(property_graph, dict) and "vertices" in property_graph and "edges" in property_graph): - log.critical("Invalid property graph format; expecting 'vertices' and 'edges'.") + if not ( + isinstance(property_graph, dict) + and "vertices" in property_graph + and "edges" in property_graph + ): + log.critical( + "Invalid property graph format; expecting 'vertices' and 'edges'." 
+ ) return items # Create sets for valid vertex and edge labels based on the schema @@ -147,18 +279,26 @@ def _extract_and_filter_label(self, schema, text) -> List[Dict[str, Any]]: def process_items(item_list, valid_labels, item_type): for item in item_list: if not isinstance(item, dict): - log.warning("Invalid property graph item type '%s'.", type(item)) + log.warning( + "Invalid property graph item type '%s'.", type(item) + ) continue if not self.NECESSARY_ITEM_KEYS.issubset(item.keys()): log.warning("Invalid item keys '%s'.", item.keys()) continue if item["label"] not in valid_labels: - log.warning("Invalid %s label '%s' has been ignored.", item_type, item["label"]) + log.warning( + "Invalid %s label '%s' has been ignored.", + item_type, + item["label"], + ) continue items.append(item) process_items(property_graph["vertices"], vertex_label_set, "vertex") process_items(property_graph["edges"], edge_label_set, "edge") except json.JSONDecodeError: - log.critical("Invalid property graph JSON! Please check the extracted JSON data carefully") + log.critical( + "Invalid property graph JSON! Please check the extracted JSON data carefully" + ) return items diff --git a/hugegraph-llm/src/hugegraph_llm/operators/util.py b/hugegraph-llm/src/hugegraph_llm/operators/util.py new file mode 100644 index 000000000..60bdc2e86 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/operators/util.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from PyCGraph import CStatus + + +def init_context(obj) -> CStatus: + try: + obj.context = obj.getGParamWithNoEmpty("wkflow_state") + obj.wk_input = obj.getGParamWithNoEmpty("wkflow_input") + if obj.context is None or obj.wk_input is None: + return CStatus(-1, "Required workflow parameters not found") + return CStatus() + except Exception as e: + return CStatus(-1, f"Failed to initialize context: {str(e)}") diff --git a/hugegraph-llm/src/hugegraph_llm/state/__init__.py b/hugegraph-llm/src/hugegraph_llm/state/__init__.py new file mode 100644 index 000000000..13a83393a --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/state/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/hugegraph-llm/src/hugegraph_llm/state/ai_state.py b/hugegraph-llm/src/hugegraph_llm/state/ai_state.py new file mode 100644 index 000000000..0543aa2b4 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/state/ai_state.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from PyCGraph import GParam, CStatus
+
+from typing import Union, List, Optional, Any
+
+
+class WkFlowInput(GParam):
+    texts: Union[str, List[str]] = None  # texts input used by ChunkSplit Node
+    language: str = None  # language configuration used by ChunkSplit Node
+    split_type: str = None  # split type used by ChunkSplit Node
+    example_prompt: str = None  # needed by graph information extract
+    schema: str = None  # Schema information required by SchemaNode
+    graph_name: str = None
+
+    def reset(self, _: CStatus) -> None:
+        self.texts = None
+        self.language = None
+        self.split_type = None
+        self.example_prompt = None
+        self.schema = None
+        self.graph_name = None
+
+
+class WkFlowState(GParam):
+    schema: Optional[str] = None  # schema message
+    simple_schema: Optional[str] = None
+    chunks: Optional[List[str]] = None
+    edges: Optional[List[Any]] = None
+    vertices: Optional[List[Any]] = None
+    triples: Optional[List[Any]] = None
+    call_count: Optional[int] = None
+
+    keywords: Optional[List[str]] = None
+    vector_result = None
+    graph_result = None
+    keywords_embeddings = None
+
+    def setup(self):
+        self.schema = None
+        self.simple_schema = None
+        self.chunks = None
+        self.edges = None
+        self.vertices = None
+        self.triples = None
+        self.call_count = None
+
+        self.keywords = None
+        self.vector_result = None
+        self.graph_result = None
+        self.keywords_embeddings = None
+
+        return CStatus()
+
+    def to_json(self):
+        """
+        Automatically returns a JSON-formatted dictionary of all non-None instance members,
+        eliminating the need to manually maintain the member list.
+
+        Returns:
+            dict: A dictionary containing non-None instance members and their serialized values.
+ """ + # Only export instance attributes (excluding methods and class attributes) whose values are not None + return { + k: v + for k, v in self.__dict__.items() + if not k.startswith("_") and v is not None + } diff --git a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py index 9fef06d2b..f61b5f843 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py @@ -22,6 +22,7 @@ from typing import Dict, Any, Union, Optional import gradio as gr +from hugegraph_llm.flows.scheduler import SchedulerSingleton from .embedding_utils import get_filename_prefix, get_index_folder_name from .hugegraph_utils import get_hg_client, clean_hg_data @@ -35,11 +36,17 @@ def get_graph_index_info(): - builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) + builder = KgBuilder( + LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client() + ) graph_summary_info = builder.fetch_graph_data().run() - folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) + folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) index_dir = str(os.path.join(resource_path, folder_name, "graph_vids")) - filename_prefix = get_filename_prefix(llm_settings.embedding_type, getattr(builder.embedding, "model_name", None)) + filename_prefix = get_filename_prefix( + llm_settings.embedding_type, getattr(builder.embedding, "model_name", None) + ) vector_index = VectorIndex.from_index_file(index_dir, filename_prefix) graph_summary_info["vid_index"] = { "embed_dim": vector_index.index.d, @@ -50,15 +57,20 @@ def get_graph_index_info(): def clean_all_graph_index(): - folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) - filename_prefix = get_filename_prefix(llm_settings.embedding_type, - getattr(Embeddings().get_embedding(), 
"model_name", None)) + folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) + filename_prefix = get_filename_prefix( + llm_settings.embedding_type, + getattr(Embeddings().get_embedding(), "model_name", None), + ) VectorIndex.clean( - str(os.path.join(resource_path, folder_name, "graph_vids")), - filename_prefix) + str(os.path.join(resource_path, folder_name, "graph_vids")), filename_prefix + ) VectorIndex.clean( str(os.path.join(resource_path, folder_name, "gremlin_examples")), - filename_prefix) + filename_prefix, + ) log.warning("Clear graph index and text2gql index successfully!") gr.Info("Clear graph index and text2gql index successfully!") @@ -71,7 +83,7 @@ def clean_all_graph_data(): def parse_schema(schema: str, builder: KgBuilder) -> Optional[str]: schema = schema.strip() - if schema.startswith('{'): + if schema.startswith("{"): try: schema = json.loads(schema) builder.import_schema(from_user_defined=schema) @@ -84,16 +96,20 @@ def parse_schema(schema: str, builder: KgBuilder) -> Optional[str]: return None -def extract_graph(input_file, input_text, schema, example_prompt) -> str: +def extract_graph_origin(input_file, input_text, schema, example_prompt) -> str: texts = read_documents(input_file, input_text) - builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) + builder = KgBuilder( + LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client() + ) if not schema: return "ERROR: please input with correct schema/format." 
error_message = parse_schema(schema, builder) if error_message: return error_message - builder.chunk_split(texts, "document", "zh").extract_info(example_prompt, "property_graph") + builder.chunk_split(texts, "document", "zh").extract_info( + example_prompt, "property_graph" + ) try: context = builder.run() @@ -103,19 +119,40 @@ def extract_graph(input_file, input_text, schema, example_prompt) -> str: { "vertices": context["vertices"], "edges": context["edges"], - "warning": "The schema may not match the Doc" + "warning": "The schema may not match the Doc", }, ensure_ascii=False, - indent=2 + indent=2, ) - return json.dumps({"vertices": context["vertices"], "edges": context["edges"]}, ensure_ascii=False, indent=2) + return json.dumps( + {"vertices": context["vertices"], "edges": context["edges"]}, + ensure_ascii=False, + indent=2, + ) + except Exception as e: # pylint: disable=broad-exception-caught + log.error(e) + raise gr.Error(str(e)) + + +def extract_graph(input_file, input_text, schema, example_prompt) -> str: + texts = read_documents(input_file, input_text) + scheduler = SchedulerSingleton.get_instance() + if not schema: + return "ERROR: please input with correct schema/format." 
+ + try: + return scheduler.schedule_flow( + "graph_extract", schema, texts, example_prompt, "property_graph" + ) except Exception as e: # pylint: disable=broad-exception-caught log.error(e) raise gr.Error(str(e)) def update_vid_embedding(): - builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) + builder = KgBuilder( + LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client() + ) builder.fetch_graph_data().build_vertex_id_semantic_index() log.debug("Operators: %s", builder.operators) try: @@ -132,7 +169,9 @@ def import_graph_data(data: str, schema: str) -> Union[str, Dict[str, Any]]: try: data_json = json.loads(data.strip()) log.debug("Import graph data: %s", data) - builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) + builder = KgBuilder( + LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client() + ) if schema: error_message = parse_schema(schema, builder) if error_message: @@ -154,7 +193,7 @@ def build_schema(input_text, query_example, few_shot): context = { "raw_texts": [input_text] if input_text else [], "query_examples": [], - "few_shot_schema": {} + "few_shot_schema": {}, } if few_shot: @@ -170,7 +209,7 @@ def build_schema(input_text, query_example, few_shot): context["query_examples"] = [ { "description": ex.get("description", ""), - "gremlin": ex.get("gremlin", "") + "gremlin": ex.get("gremlin", ""), } for ex in parsed_examples if isinstance(ex, dict) and "description" in ex and "gremlin" in ex @@ -178,7 +217,9 @@ def build_schema(input_text, query_example, few_shot): except json.JSONDecodeError as e: raise gr.Error(f"Query Examples is not in a valid JSON format: {e}") from e - builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) + builder = KgBuilder( + LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client() + ) try: schema = builder.build_schema().run(context) except Exception as e: diff --git 
a/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py index 62bcdd9cb..138b0d359 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py @@ -23,11 +23,12 @@ from hugegraph_llm.config import resource_path, huge_settings, llm_settings from hugegraph_llm.indices.vector_index import VectorIndex -from hugegraph_llm.models.embeddings.init_embedding import Embeddings -from hugegraph_llm.models.llms.init_llm import LLMs -from hugegraph_llm.operators.kg_construction_task import KgBuilder -from hugegraph_llm.utils.embedding_utils import get_filename_prefix, get_index_folder_name -from hugegraph_llm.utils.hugegraph_utils import get_hg_client +from hugegraph_llm.models.embeddings.init_embedding import model_map +from hugegraph_llm.flows.scheduler import SchedulerSingleton +from hugegraph_llm.utils.embedding_utils import ( + get_filename_prefix, + get_index_folder_name, +) def read_documents(input_file, input_text): @@ -49,7 +50,9 @@ def read_documents(input_file, input_text): texts.append(text) elif full_path.endswith(".pdf"): # TODO: support PDF file - raise gr.Error("PDF will be supported later! Try to upload text/docx now") + raise gr.Error( + "PDF will be supported later! 
Try to upload text/docx now" + ) else: raise gr.Error("Please input txt or docx file.") else: @@ -59,33 +62,44 @@ def read_documents(input_file, input_text): # pylint: disable=C0301 def get_vector_index_info(): - folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) - filename_prefix = get_filename_prefix(llm_settings.embedding_type, - getattr(Embeddings().get_embedding(), "model_name", None)) + folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) + filename_prefix = get_filename_prefix( + llm_settings.embedding_type, model_map.get(llm_settings.embedding_type) + ) chunk_vector_index = VectorIndex.from_index_file( str(os.path.join(resource_path, folder_name, "chunks")), filename_prefix, - record_miss=False + record_miss=False, ) graph_vid_vector_index = VectorIndex.from_index_file( - str(os.path.join(resource_path, folder_name, "graph_vids")), - filename_prefix + str(os.path.join(resource_path, folder_name, "graph_vids")), filename_prefix + ) + return json.dumps( + { + "embed_dim": chunk_vector_index.index.d, + "vector_info": { + "chunk_vector_num": chunk_vector_index.index.ntotal, + "graph_vid_vector_num": graph_vid_vector_index.index.ntotal, + "graph_properties_vector_num": len(chunk_vector_index.properties), + }, + }, + ensure_ascii=False, + indent=2, ) - return json.dumps({ - "embed_dim": chunk_vector_index.index.d, - "vector_info": { - "chunk_vector_num": chunk_vector_index.index.ntotal, - "graph_vid_vector_num": graph_vid_vector_index.index.ntotal, - "graph_properties_vector_num": len(chunk_vector_index.properties) - } - }, ensure_ascii=False, indent=2) def clean_vector_index(): - folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) - filename_prefix = get_filename_prefix(llm_settings.embedding_type, - getattr(Embeddings().get_embedding(), "model_name", None)) - VectorIndex.clean(str(os.path.join(resource_path, folder_name, "chunks")), 
filename_prefix) + folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) + filename_prefix = get_filename_prefix( + llm_settings.embedding_type, model_map.get(llm_settings.embedding_type) + ) + VectorIndex.clean( + str(os.path.join(resource_path, folder_name, "chunks")), filename_prefix + ) gr.Info("Clean vector index successfully!") @@ -93,6 +107,5 @@ def build_vector_index(input_file, input_text): if input_file and input_text: raise gr.Error("Please only choose one between file and text.") texts = read_documents(input_file, input_text) - builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) - context = builder.chunk_split(texts, "paragraph", "zh").build_vector_index().run() - return json.dumps(context, ensure_ascii=False, indent=2) + scheduler = SchedulerSingleton.get_instance() + return scheduler.schedule_flow("build_vector_index", texts) From 2f01de574fe91406539bf054fab0a4b199c07d14 Mon Sep 17 00:00:00 2001 From: Linyu <94553312+weijinglin@users.noreply.github.com> Date: Thu, 25 Sep 2025 15:09:38 +0800 Subject: [PATCH 2/5] refactor: refactor hugegraph-ai to integrate with CGraph & port some usecases in web demo (#49) --- .../spec/hugegraph-llm/fixed_flow/design.md | 643 ++++++++++++++++++ .../hugegraph-llm/fixed_flow/requirements.md | 24 + .../spec/hugegraph-llm/fixed_flow/tasks.md | 36 + .../demo/rag_demo/vector_graph_block.py | 21 +- .../src/hugegraph_llm/flows/build_schema.py | 71 ++ .../hugegraph_llm/flows/build_vector_index.py | 4 +- .../flows/get_graph_index_info.py | 68 ++ .../src/hugegraph_llm/flows/graph_extract.py | 58 +- .../hugegraph_llm/flows/import_graph_data.py | 65 ++ .../hugegraph_llm/flows/prompt_generate.py | 63 ++ .../src/hugegraph_llm/flows/scheduler.py | 31 +- .../flows/update_vid_embeddings.py | 47 ++ .../src/hugegraph_llm/flows/utils.py | 34 + .../src/hugegraph_llm/nodes/base_node.py | 71 ++ .../nodes/document_node/chunk_split.py | 43 ++ 
.../hugegraph_node/commit_to_hugegraph.py | 35 + .../nodes/hugegraph_node/fetch_graph_data.py | 33 + .../nodes/hugegraph_node/schema.py | 74 ++ .../nodes/index_node/build_semantic_index.py | 34 + .../nodes/index_node/build_vector_index.py | 34 + .../nodes/llm_node/extract_info.py | 52 ++ .../nodes/llm_node/prompt_generate.py | 59 ++ .../nodes/llm_node/schema_build.py | 91 +++ hugegraph-llm/src/hugegraph_llm/nodes/util.py | 27 + .../operators/common_op/check_schema.py | 160 ----- .../operators/document_op/chunk_split.py | 58 -- .../hugegraph_op/commit_to_hugegraph.py | 127 +++- .../operators/hugegraph_op/schema_manager.py | 75 -- .../operators/index_op/build_vector_index.py | 48 -- .../operators/llm_op/info_extract.py | 146 ---- .../llm_op/property_graph_extract.py | 127 +--- .../src/hugegraph_llm/state/ai_state.py | 28 + .../hugegraph_llm/utils/graph_index_utils.py | 40 ++ 33 files changed, 1811 insertions(+), 716 deletions(-) create mode 100644 .vibedev/spec/hugegraph-llm/fixed_flow/design.md create mode 100644 .vibedev/spec/hugegraph-llm/fixed_flow/requirements.md create mode 100644 .vibedev/spec/hugegraph-llm/fixed_flow/tasks.md create mode 100644 hugegraph-llm/src/hugegraph_llm/flows/build_schema.py create mode 100644 hugegraph-llm/src/hugegraph_llm/flows/get_graph_index_info.py create mode 100644 hugegraph-llm/src/hugegraph_llm/flows/import_graph_data.py create mode 100644 hugegraph-llm/src/hugegraph_llm/flows/prompt_generate.py create mode 100644 hugegraph-llm/src/hugegraph_llm/flows/update_vid_embeddings.py create mode 100644 hugegraph-llm/src/hugegraph_llm/flows/utils.py create mode 100644 hugegraph-llm/src/hugegraph_llm/nodes/base_node.py create mode 100644 hugegraph-llm/src/hugegraph_llm/nodes/document_node/chunk_split.py create mode 100644 hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/commit_to_hugegraph.py create mode 100644 hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/fetch_graph_data.py create mode 100644 
hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py create mode 100644 hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_semantic_index.py create mode 100644 hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_vector_index.py create mode 100644 hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py create mode 100644 hugegraph-llm/src/hugegraph_llm/nodes/llm_node/prompt_generate.py create mode 100644 hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py create mode 100644 hugegraph-llm/src/hugegraph_llm/nodes/util.py diff --git a/.vibedev/spec/hugegraph-llm/fixed_flow/design.md b/.vibedev/spec/hugegraph-llm/fixed_flow/design.md new file mode 100644 index 000000000..c5777236d --- /dev/null +++ b/.vibedev/spec/hugegraph-llm/fixed_flow/design.md @@ -0,0 +1,643 @@ +# Hugegraph-ai 固定工作流执行引擎设计文档 + +## 概述 + +Hugegraph固定工作流执行引擎是用来执行固定工作流的工作流执行引擎,每个工作流对应到实际Web Demo的一个具体用例,包括向量索引的构建,图索引的构建等等。该引擎基于PyCGraph框架构建,提供了高性能、可复用的流水线调度能力。 + +### 设计目标 + +- **性能优异**:通过流水线复用机制保证固定工作流的执行性能 +- **高可靠性**:确保数据一致性和故障恢复能力,提供完善的错误处理机制 +- **易于扩展**:能够简单轻松地新增固定工作流,支持动态调度 +- **资源优化**:通过流水线池化管理,减少重复构图开销 + +### 技术栈 + +- **PyCGraph**:基于C++的高性能图计算框架,提供GPipeline和GPipelineManager +- **Python**:主要开发语言,提供业务逻辑和接口层 +- **Threading**:支持并发调度和线程安全 + +### 模块分层 +```text +hugegraph-llm/ +└── src/ + └── hugegraph_llm/ + ├── api/ # FastAPI 接口层,提供 rag_api、admin_api 等服务 + ├── config/ # 配置管理,包含各类配置与生成工具 + ├── demo/ # Gradio Web Demo 及相关交互应用 + ├── document/ # 文档处理与分块等工具 + ├── enums/ # 枚举类型定义 + ├── flows/ # 工作流调度与核心流程(如向量/图索引构建、数据导入等) + │ ├── __init__.py + │ ├── common.py # BaseFlow抽象基类 + │ ├── scheduler.py # 调度器核心实现 + │ ├── build_vector_index.py # 向量索引构建工作流 + │ ├── graph_extract.py # 图抽取工作流 + │ ├── import_graph_data.py # 图数据导入工作流 + │ ├── update_vid_embeddings.py # 向量更新工作流 + │ ├── get_graph_index_info.py # 图索引信息获取工作流 + │ ├── build_schema.py # 模式构建工作流 + │ └── prompt_generate.py # 提示词生成工作流 + ├── indices/ # 各类索引实现(向量、图、关键词等) + ├── middleware/ # 中间件与请求处理 + ├── models/ # 
LLM、Embedding、Reranker 等模型相关 + ├── nodes/ # Node调度层,负责Operator生命周期和上下文管理 + │ ├── base_node.py + │ ├── document_node/ + │ ├── hugegraph_node/ + │ ├── index_node/ + │ ├── llm_node/ + │ └── util.py + ├── operators/ # 主要算子与任务(如 KG 构建、GraphRAG、Text2Gremlin 等) + ├── resources/ # 资源文件(Prompt、示例、Gremlin 模板等) + ├── state/ # 状态管理 + ├── utils/ # 工具类与通用方法 + └── __init__.py # 包初始化 +``` + +## 架构设计 + +### 整体架构 + +> 新架构在Flow与Operator之间引入Node层,Node负责Operator的生命周期管理、上下文绑定、参数区解耦和并发安全,所有Flow均通过Node组装,Operator只关注业务实现。 + +#### 架构图 + +```mermaid +graph TB + subgraph UserLayer["用户层"] + User["用户请求"] + end + + subgraph SchedulerLayer["调度层"] + Scheduler["Scheduler
调度器"] + Singleton["SchedulerSingleton
单例管理器"] + end + + subgraph FlowLayer["工作流层"] + Pool["pipeline_pool
流水线池"] + BVI["BuildVectorIndexFlow
向量索引构建"] + GE["GraphExtractFlow
图抽取工作流"] + end + + subgraph PyCGraphLayer["PyCGraph层"] + Manager1["GPipelineManager
向量索引管理器"] + Manager2["GPipelineManager
图抽取管理器"] + Pipeline1["GPipeline
向量索引流水线"] + Pipeline2["GPipeline
图抽取流水线"] + end + + subgraph OperatorLayer["算子层"] + ChunkSplit["ChunkSplitNode
文档分块"] + BuildVector["BuildVectorIndexNode
向量索引构建"] + SchemaNode["SchemaNode
模式管理"] + InfoExtract["ExtractNode
信息抽取"] + PropGraph["Commit2GraphNode
图数据导入"] + FetchNode["FetchGraphDataNode
图数据拉取"] + SemanticIndex["BuildSemanticIndexNode
语义索引构建"] + end + + subgraph StateLayer["状态层"] + WkInput["wkflow_input
工作流输入"] + WkState["wkflow_state
工作流状态"] + end + + User --> Scheduler + Scheduler --> Singleton + Scheduler --> Pool + Pool --> BVI + Pool --> GE + BVI --> Manager1 + GE --> Manager2 + Manager1 --> Pipeline1 + Manager2 --> Pipeline2 + Pipeline1 --> ChunkSplit + Pipeline1 --> BuildVector + Pipeline2 --> SchemaNode + Pipeline2 --> ChunkSplit + Pipeline2 --> InfoExtract + Pipeline2 --> PropGraph + Pipeline1 --> WkInput + Pipeline1 --> WkState + Pipeline2 --> WkInput + Pipeline2 --> WkState + + style Scheduler fill:#e1f5fe + style Pool fill:#f3e5f5 + style Manager1 fill:#fff3e0 + style Manager2 fill:#fff3e0 + style Pipeline1 fill:#e8f5e8 + style Pipeline2 fill:#e8f5e8 +``` + +#### 调度流程图 + +```mermaid +flowchart TD + Start([开始]) --> CheckFlow{检查工作流
是否支持} + CheckFlow -->|否| Error1[抛出ValueError] + CheckFlow -->|是| FetchPipeline[从Manager获取
可复用Pipeline] + + FetchPipeline --> IsNull{Pipeline
是否为null} + + IsNull -->|是| BuildNew[构建新Pipeline] + BuildNew --> InitPipeline[初始化Pipeline] + InitPipeline --> InitCheck{初始化
是否成功} + InitCheck -->|否| Error2[记录错误并中止] + InitCheck -->|是| RunPipeline[执行Pipeline] + RunPipeline --> RunCheck{执行
是否成功} + RunCheck -->|否| Error3[记录错误并中止] + RunCheck -->|是| PostDeal[后处理结果] + PostDeal --> AddToPool[添加到复用池] + AddToPool --> Return[返回结果] + + IsNull -->|否| PrepareInput[准备输入数据] + PrepareInput --> RunReused[执行复用Pipeline] + RunReused --> ReusedCheck{执行
是否成功} + ReusedCheck -->|否| Error4[抛出RuntimeError] + ReusedCheck -->|是| PostDealReused[后处理结果] + PostDealReused --> ReleasePipeline[释放Pipeline] + ReleasePipeline --> Return + + Error1 --> End([结束]) + Error2 --> End + Error3 --> End + Error4 --> End + Return --> End + + style Start fill:#4caf50 + style End fill:#f44336 + style CheckFlow fill:#ff9800 + style IsNull fill:#ff9800 + style InitCheck fill:#ff9800 + style RunCheck fill:#ff9800 + style ReusedCheck fill:#ff9800 +``` + +### 核心组件 + +#### 1. Scheduler(调度器) +- **职责**:调度中心,维护 `pipeline_pool`,提供统一的工作流调度接口 +- **特性**: + - 支持多种工作流类型(build_vector_index、graph_extract、import_graph_data、update_vid_embeddings、get_graph_index_info、build_schema、prompt_generate等) + - 流水线池化管理,支持复用 + - 线程安全的单例模式 + - 可配置的最大流水线数量 + +#### 2. GPipelineManager(流水线管理器) +- **来源**:PyCGraph框架提供 +- **职责**:负责流水线对象 `GPipeline` 的获取、添加、释放与复用 +- **特性**: + - 自动管理流水线生命周期 + - 支持流水线复用和资源回收 + - 提供fetch/add/release操作接口 + +#### 3. BaseFlow(工作流基类) +- **职责**:工作流构建与前后处理抽象 +- **接口**: + - `prepare()`: 预处理接口,准备输入数据 + - `build_flow()`: 组装Node并注册依赖关系 + - `post_deal()`: 后处理接口,处理执行结果 +- **实现**: + - `BuildVectorIndexFlow`: 向量索引构建工作流 + - `GraphExtractFlow`: 图抽取工作流 + - `ImportGraphDataFlow`: 图数据导入工作流 + - `UpdateVidEmbeddingsFlows`: 向量更新工作流 + - `GetGraphIndexInfoFlow`: 图索引信息获取工作流 + - `BuildSchemaFlow`: 模式构建工作流 + - `PromptGenerateFlow`: 提示词生成工作流 + +#### 4. Node(节点调度器) +- **职责**:作为Operator的生命周期管理者,负责参数区绑定、上下文初始化、并发安全、异常处理等。 +- **特性**: + - 统一生命周期接口(init、node_init、run、operator_schedule) + - 通过参数区(wkflow_input/wkflow_state)与Flow/Operator解耦 + - Operator只需实现run(data_json)方法,Node负责调度和结果写回 + - 典型Node如:ChunkSplitNode、BuildVectorIndexNode、SchemaNode、ExtractNode、Commit2GraphNode、FetchGraphDataNode、BuildSemanticIndexNode、SchemaBuildNode、PromptGenerateNode等 + +#### 5. Operator(算子) +- **职责**:实现具体的业务原子操作 +- **特性**: + - 只需关注自身业务逻辑实现 + - 由Node统一调度 + +#### 6. 
GPipeline(流水线实例) +- **来源**:PyCGraph框架提供 +- **职责**:具体流水线实例,包含参数区与节点DAG拓扑 +- **参数区**: + - `wkflow_input`: 流水线运行输入 + - `wkflow_state`: 流水线运行状态与中间结果 + +### 核心数据结构 + +```python +# Scheduler核心数据结构 +Scheduler.pipeline_pool: Dict[str, Any] = { + "build_vector_index": { + "manager": GPipelineManager(), + "flow": BuildVectorIndexFlow(), + }, + "graph_extract": { + "manager": GPipelineManager(), + "flow": GraphExtractFlow(), + } +} +``` + +### 调度流程 + +#### schedule_flow方法执行流程 + +1. **工作流验证**:校验 `flow` 是否受支持,查表获取对应的 `manager` 与 `flow` 实例 + +2. **流水线获取**:从 `manager.fetch()` 获取可复用的 `GPipeline` + +3. **新流水线处理**(当fetch()返回None时): + - 调用 `flow.build_flow(*args, **kwargs)` 构建新流水线 + - 调用 `pipeline.init()` 完成初始化,失败则记录错误并中止 + - 调用 `pipeline.run()` 执行,失败则中止 + - 调用 `flow.post_deal(pipeline)` 生成输出 + - 调用 `manager.add(pipeline)` 将流水线加入可复用池 + +4. **复用流水线处理**(当fetch()返回现有流水线时): + - 从 `pipeline.getGParamWithNoEmpty("wkflow_input")` 获取输入对象 + - 调用 `flow.prepare(prepared_input, *args, **kwargs)` 进行参数刷新 + - 调用 `pipeline.run()` 执行,失败则中止 + - 调用 `flow.post_deal(pipeline)` 生成输出 + - 调用 `manager.release(pipeline)` 归还流水线 + +### 并发与复用策略 + +#### 线程安全 +- `SchedulerSingleton` 使用双重检查锁保证全局单例 +- 线程安全获取 `Scheduler` 实例 + +#### 资源管理 +- 每种 `flow` 拥有独立的 `GPipelineManager` +- 最大并发量由 `Scheduler.max_pipeline` 与底层 `GPipelineManager` 策略共同约束 +- 通过 `fetch/add/release` 机制减少重复构图的开销 + +#### 性能优化 +- 流水线复用机制适合高频相同工作流场景 +- 减少重复初始化和构图的时间开销 +- 支持并发执行多个工作流实例 + +### 错误处理与日志 + +#### 错误检测 +- 对 `init/run` 的 `Status.isErr()` 进行检测 +- 统一抛出 `RuntimeError` 并记录详细 `status.getInfo()` +- 提供完整的错误堆栈信息 + +#### 日志记录 +- 使用统一的日志系统记录关键操作 +- 记录流水线执行状态和错误信息 +- 支持不同级别的日志输出 + +#### 结果处理 +- `flow.post_deal` 负责将 `wkflow_state` 转换为对外可消费结果(如JSON) +- 提供标准化的输出格式 +- 支持错误信息的友好展示 + +### 扩展指引 + +#### 新增Node/Operator/Flow步骤 +1. 实现Operator业务逻辑(如ChunkSplit/BuildVectorIndex/InfoExtract等) +2. 实现对应Node(继承BaseNode,负责参数区绑定和调度Operator) +3. 在Flow中组装Node,注册依赖关系 +4. 
在Scheduler注册新的Flow + +#### 输入输出约定 +- 统一使用 `wkflow_input` 作为输入载体 +- 统一使用 `wkflow_state` 作为状态与结果容器 +- 确保可复用流水线在不同请求间可被快速重置 + +#### 最佳实践 +- 保持Flow类的无状态设计 +- 合理使用流水线复用机制 +- 提供完善的错误处理和日志记录 +- 遵循统一的接口规范 + +## Flow对象设计 + +### BaseFlow抽象基类 + +```python +class BaseFlow(ABC): + """ + Base class for flows, defines three interface methods: prepare, build_flow, and post_deal. + """ + + @abstractmethod + def prepare(self, prepared_input: WkFlowInput, *args, **kwargs): + """ + Pre-processing interface. + """ + pass + + @abstractmethod + def build_flow(self, *args, **kwargs): + """ + Interface for building the flow. + """ + pass + + @abstractmethod + def post_deal(self, *args, **kwargs): + """ + Post-processing interface. + """ + pass +``` + +### 接口说明 + +每个Flow对象都需要实现三个核心接口: + +- **prepare**: 用来准备整个workflow的输入数据,设置工作流参数 +- **build_flow**: 用来构建整个workflow的流水线,注册节点和依赖关系 +- **post_deal**: 用来处理workflow的执行结果,转换为对外输出格式 + +### 具体实现示例 + +#### BuildVectorIndexFlow(向量索引构建工作流) + +```python +class BuildVectorIndexFlow(BaseFlow): + def __init__(self): + pass + + def prepare(self, prepared_input: WkFlowInput, texts): + prepared_input.texts = texts + prepared_input.language = "zh" + prepared_input.split_type = "paragraph" + return + + def build_flow(self, texts): + pipeline = GPipeline() + # prepare for workflow input + prepared_input = WkFlowInput() + self.prepare(prepared_input, texts) + + pipeline.createGParam(prepared_input, "wkflow_input") + pipeline.createGParam(WkFlowState(), "wkflow_state") + + chunk_split_node = ChunkSplitNode() + build_vector_node = BuildVectorIndexNode() + pipeline.registerGElement(chunk_split_node, set(), "chunk_split") + pipeline.registerGElement(build_vector_node, {chunk_split_node}, "build_vector") + + return pipeline + + def post_deal(self, pipeline=None): + res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + return json.dumps(res, ensure_ascii=False, indent=2) +``` + +#### GraphExtractFlow(图抽取工作流) + +```python +class GraphExtractFlow(BaseFlow): + def 
__init__(self): + pass + + def prepare(self, prepared_input: WkFlowInput, schema, texts, example_prompt, extract_type): + prepared_input.texts = texts + prepared_input.language = "zh" + prepared_input.split_type = "document" + prepared_input.example_prompt = example_prompt + prepared_input.schema = schema + prepare_schema(prepared_input, schema) + return + + def build_flow(self, schema, texts, example_prompt, extract_type): + pipeline = GPipeline() + prepared_input = WkFlowInput() + self.prepare(prepared_input, schema, texts, example_prompt, extract_type) + + pipeline.createGParam(prepared_input, "wkflow_input") + pipeline.createGParam(WkFlowState(), "wkflow_state") + schema_node = SchemaNode() + + chunk_split_node = ChunkSplitNode() + graph_extract_node = ExtractNode() + + pipeline.registerGElement(schema_node, set(), "schema_node") + pipeline.registerGElement(chunk_split_node, set(), "chunk_split") + pipeline.registerGElement( + graph_extract_node, {schema_node, chunk_split_node}, "graph_extract" + ) + + return pipeline + + def post_deal(self, pipeline=None): + res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + vertices = res.get("vertices", []) + edges = res.get("edges", []) + if not vertices and not edges: + log.info("Please check the schema.(The schema may not match the Doc)") + return json.dumps( + { + "vertices": vertices, + "edges": edges, + "warning": "The schema may not match the Doc", + }, + ensure_ascii=False, + indent=2, + ) + return json.dumps( + {"vertices": vertices, "edges": edges}, + ensure_ascii=False, + indent=2, + ) +``` + +## Node对象设计 + +### 节点生命周期 + +节点以 GNode 为抽象基类,统一生命周期与状态返回。方法职责与约定如下: + +#### 初始化阶段 + +- **init()**: + - **责任**:完成节点级初始化工作(如绑定共享上下文、准备参数区),确保节点具备运行所需的最小环境 + - **约定**:仅做轻量初始化,不执行业务逻辑;返回状态用于判断是否可继续 + +- **node_init()**: + - **责任**:解析与校验本次运行所需的输入(通常来自 wk_input),构建运行期依赖(如内部配置、变换器、资源句柄) + - **约定**:输入缺失或不合法时,应返回错误状态并中止后续执行;不产生对外可见的业务结果 + +#### 运行阶段 + +- **run()**: + - **责任**:执行业务主流程(纯计算或 
I/O),在完成后将节点产出写入共享状态(wkflow_state/上下文) + - **约定**: + - 进入前应先调用 node_init() 并检查其返回状态 + - 对共享状态的写操作需遵循并发安全约定(如加锁/解锁) + - 出错使用统一状态返回,不抛出未捕获异常到流程编排层 + +### 输入/输出与上下文约定 + +- **输入**:通过编排层预置于参数区(如 wk_input),节点在 node_init() 中读取并校验 +- **输出**:通过共享状态容器(如 wkflow_state/上下文)对外暴露,键/字段命名应稳定可预期,供下游节点消费 + +### 错误处理约定 + +- 统一以状态对象表示成功/失败与信息;错误应尽早返回,避免在 run() 中继续副作用操作 +- 对可预见的校验类错误使用明确的错误信息,便于定位问题与编排层记录 + +### 并发与可重入约定 + +- 共享状态的写入需在临界区内完成;读取视数据一致性要求决定是否加锁 +- 节点应尽量保持无副作用或将副作用范围收敛在可控区域,以支持重试与复用 + +### 可测试性与解耦 + +- 业务纯逻辑应与框架交互解耦,优先封装为可单测的纯函数/内部方法 +- 节点仅负责生命周期编排与上下文读写,具体策略与算法通过内部可替换组件提供 + +### 节点类型 + +#### 文档处理节点 +- **ChunkSplitNode**: 文档分块处理节点 + - 功能:将输入文档按照指定策略进行分块 + - 输入:原始文档文本 + - 输出:分块后的文档片段 + +#### 索引构建节点 +- **BuildVectorIndexNode**: 向量索引构建节点 + - 功能:基于文档分块构建向量索引 + - 输入:文档分块 + - 输出:向量索引数据 + +#### 模式管理节点 +- **SchemaManagerNode**: 图模式管理节点 + - 功能:从HugeGraph获取图模式信息 + - 输入:图名称 + - 输出:图模式定义 + +- **CheckSchemaNode**: 模式校验节点 + - 功能:校验用户定义的图模式 + - 输入:用户定义的JSON模式 + - 输出:校验后的模式定义 + +#### 图抽取节点 +- **InfoExtractNode**: 信息抽取节点 + - 功能:从文档中抽取三元组信息 + - 输入:文档分块和模式定义 + - 输出:抽取的三元组数据 + +- **PropertyGraphExtractNode**: 属性图抽取节点 + - 功能:从文档中抽取属性图结构 + - 输入:文档分块和模式定义 + - 输出:抽取的顶点和边数据 + +#### 模式构建节点 +- **SchemaBuildNode**: 模式构建节点 + - 功能:基于文档和查询示例构建图模式 + - 输入:文档文本、查询示例、少样本模式 + - 输出:构建的图模式定义 + +#### 提示词生成节点 +- **PromptGenerateNode**: 提示词生成节点 + - 功能:基于源文本、场景和示例名称生成提示词 + - 输入:源文本、场景、示例名称 + - 输出:生成的提示词 + + +## 测试策略 + +### 测试目标 + +目前的测试策略主要目标是保证移植之后的workflow和移植之前的workflow执行结果、程序行为一致。 + +### 测试范围 + +#### 1. 功能测试 +- **工作流执行结果一致性**:确保新架构下的工作流执行结果与原有实现完全一致 +- **输入输出格式验证**:验证输入参数处理和输出格式转换的正确性 +- **错误处理测试**:确保错误场景下的行为与预期一致 + +#### 2. 性能测试 +- **流水线复用效果**:验证流水线复用机制的性能提升效果 +- **并发执行测试**:测试多工作流并发执行的稳定性和性能 +- **资源使用测试**:监控内存和CPU使用情况,确保资源使用合理 + +#### 3. 稳定性测试 +- **长时间运行测试**:验证系统在长时间运行下的稳定性 +- **异常恢复测试**:测试系统在异常情况下的恢复能力 +- **内存泄漏测试**:确保流水线复用不会导致内存泄漏 + +### 测试方法 + +#### 1. 单元测试 +- 对每个Flow类进行单元测试 +- 对每个Node类进行单元测试 +- 对Scheduler调度逻辑进行测试 + +#### 2. 集成测试 +- 端到端工作流测试 +- 多工作流组合测试 +- 与外部系统集成测试 + +#### 3. 
性能基准测试 +- 建立性能基准线 +- 对比新旧架构的性能差异 +- 监控关键性能指标 + +### 测试数据 + +#### 1. 标准测试数据集 +- 准备标准化的测试文档 +- 准备标准化的图模式定义 +- 准备标准化的期望输出结果 + +#### 2. 边界测试数据 +- 空输入测试 +- 大文件测试 +- 特殊字符测试 +- 异常格式测试 + +### 测试环境 + +#### 1. 开发环境测试 +- 本地开发环境的功能验证 +- 快速迭代测试 + +#### 2. 测试环境验证 +- 模拟生产环境的完整测试 +- 性能压力测试 + +#### 3. 生产环境验证 +- 灰度发布验证 +- 生产环境监控 + +### 测试自动化 + +#### 1. CI/CD集成 +- 自动化测试流程集成 +- 代码提交触发测试 +- 测试结果自动报告 + +#### 2. 回归测试 +- 定期执行回归测试 +- 确保新功能不影响现有功能 +- 性能回归检测 + +### 测试指标 + +#### 1. 功能指标 +- 测试覆盖率 > 90% +- 功能正确性 100% +- 错误处理覆盖率 > 95% + +#### 2. 性能指标 +- 响应时间提升 > 20% +- 吞吐量提升 > 30% +- 资源使用优化 > 15% + +#### 3. 稳定性指标 +- 系统可用性 > 99.9% +- 平均故障恢复时间 < 5分钟 +- 内存泄漏率 = 0% diff --git a/.vibedev/spec/hugegraph-llm/fixed_flow/requirements.md b/.vibedev/spec/hugegraph-llm/fixed_flow/requirements.md new file mode 100644 index 000000000..095027369 --- /dev/null +++ b/.vibedev/spec/hugegraph-llm/fixed_flow/requirements.md @@ -0,0 +1,24 @@ +## 需求列表 + +### 核心框架设计 + +**核心**:Scheduler类中的schedule_flow设计与实现 + +**验收标准**: +1.1. 核心框架尽可能复用资源,避免资源的重复分配和释放 +1.2. 应该保证正常的请求处理指标要求 +1.3. 应该能够配置框架整体使用的资源上限 + +### 固定工作流移植 + +**核心**:移植Web Demo中的所有用例 +2.1. 保证使用核心框架移植后的工作流的程序行为和移植之前保持一致即可 + +**已完成的工作流类型**: +- build_vector_index: 向量索引构建工作流 +- graph_extract: 图抽取工作流 +- import_graph_data: 图数据导入工作流 +- update_vid_embeddings: 向量更新工作流 +- get_graph_index_info: 图索引信息获取工作流 +- build_schema: 模式构建工作流 +- prompt_generate: 提示词生成工作流 diff --git a/.vibedev/spec/hugegraph-llm/fixed_flow/tasks.md b/.vibedev/spec/hugegraph-llm/fixed_flow/tasks.md new file mode 100644 index 000000000..a84aee2ff --- /dev/null +++ b/.vibedev/spec/hugegraph-llm/fixed_flow/tasks.md @@ -0,0 +1,36 @@ +# HugeGraph-ai 固定工作流框架设计和用例移植 + +本文档将 HugeGraph 固定工作流框架设计和用例移植转换为一系列可执行的编码任务。 + +## 1. schedule_flow设计与实现 + +- [x] **1.1 构建Scheduler框架1.0** + - 需要能够复用已经创建过的Pipeline(Pipeline Pooling) + - 使用CGraph(Graph-based engine)作为底层执行引擎 + - 不同Node之间松耦合 + +- [ ] **1.2 优化Scheduler框架资源配置** + - 支持用户配置底层线程池参数 + - 现有的workflow可能会根据输入有细小的变化,导致相同的用例得到不同的workflow,怎么解决这个问题呢? 
+ - Node/Operator解耦,Node负责生命周期和上下文,Operator只关注业务逻辑 + - Flow只负责组装Node,所有业务逻辑下沉到Node/Operator + - Scheduler支持多类型Flow注册,注册方式更灵活 + +- [ ] **1.3 优化Scheduler框架资源使用** + - 根据负载控制每个PipelineManager管理的Pipeline数量,实现动态扩缩容 + - Node层支持参数区自动绑定和并发安全 + - Operator只需实现run(data_json)方法,Node负责调度和结果写回 + +## 2. 固定工作流用例移植 + +- [x] **2.1 build_vector_index workflow移植** +- [x] **2.2 graph_extract workflow移植** +- [x] **2.3 import_graph_data workflow移植** + - 基于Node/Operator机制实现import_graph_data工作流 +- [x] **2.4 update_vid_embeddings workflow移植** + - 基于Node/Operator机制实现update_vid_embeddings工作流 +- [x] **2.5 get_graph_index_info workflow移植** +- [x] **2.6 build_schema workflow移植** + - 基于Node/Operator机制实现build_schema工作流 +- [x] **2.7 prompt_generate workflow移植** + - 基于Node/Operator机制实现prompt_generate工作流 diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py index 9897f420f..4aa476942 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py @@ -26,8 +26,7 @@ from hugegraph_llm.config import huge_settings from hugegraph_llm.config import prompt from hugegraph_llm.config import resource_path -from hugegraph_llm.models.llms.init_llm import LLMs -from hugegraph_llm.operators.llm_op.prompt_generate import PromptGenerate +from hugegraph_llm.flows.scheduler import SchedulerSingleton from hugegraph_llm.utils.graph_index_utils import ( get_graph_index_info, clean_all_graph_index, @@ -61,7 +60,7 @@ def store_prompt(doc, schema, example_prompt): def generate_prompt_for_ui(source_text, scenario, example_name): """ - Handles the UI logic for generating a new prompt. It calls the PromptGenerate operator. + Handles the UI logic for generating a new prompt using the new workflow architecture. 
""" if not all([source_text, scenario, example_name]): gr.Warning( @@ -69,19 +68,13 @@ def generate_prompt_for_ui(source_text, scenario, example_name): ) return gr.update() try: - prompt_generator = PromptGenerate(llm=LLMs().get_chat_llm()) - context = { - "source_text": source_text, - "scenario": scenario, - "example_name": example_name, - } - result_context = prompt_generator.run(context) - # Presents the result of generating prompt - generated_prompt = result_context.get( - "generated_extract_prompt", "Generation failed. Please check the logs." + # using new architecture + scheduler = SchedulerSingleton.get_instance() + result = scheduler.schedule_flow( + "prompt_generate", source_text, scenario, example_name ) gr.Info("Prompt generated successfully!") - return generated_prompt + return result except Exception as e: log.error("Error generating Prompt: %s", e, exc_info=True) raise gr.Error(f"Error generating Prompt: {e}") from e diff --git a/hugegraph-llm/src/hugegraph_llm/flows/build_schema.py b/hugegraph-llm/src/hugegraph_llm/flows/build_schema.py new file mode 100644 index 000000000..6bbcb8512 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/flows/build_schema.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from hugegraph_llm.flows.common import BaseFlow +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from hugegraph_llm.nodes.llm_node.schema_build import SchemaBuildNode +from hugegraph_llm.utils.log import log + +import json +from PyCGraph import GPipeline + + +class BuildSchemaFlow(BaseFlow): + def __init__(self): + pass + + def prepare( + self, + prepared_input: WkFlowInput, + texts=None, + query_examples=None, + few_shot_schema=None, + ): + prepared_input.texts = texts + # Optional fields packed into wk_input for SchemaBuildNode + # Keep raw values; node will parse if strings + prepared_input.query_examples = query_examples + prepared_input.few_shot_schema = few_shot_schema + return + + def build_flow(self, texts=None, query_examples=None, few_shot_schema=None): + pipeline = GPipeline() + prepared_input = WkFlowInput() + self.prepare( + prepared_input, + texts=texts, + query_examples=query_examples, + few_shot_schema=few_shot_schema, + ) + + pipeline.createGParam(prepared_input, "wkflow_input") + pipeline.createGParam(WkFlowState(), "wkflow_state") + + schema_build_node = SchemaBuildNode() + pipeline.registerGElement(schema_build_node, set(), "schema_build") + + return pipeline + + def post_deal(self, pipeline=None): + state_json = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + if "schema" not in state_json: + return "" + res = state_json["schema"] + try: + formatted_schema = json.dumps(res, ensure_ascii=False, indent=2) + return formatted_schema + except (TypeError, ValueError) as e: + log.error("Failed to format schema: %s", e) + return str(res) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py b/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py index f1ee8c1c4..9a07b5dba 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py @@ -14,13 +14,13 @@ # limitations under the License. 
from hugegraph_llm.flows.common import BaseFlow +from hugegraph_llm.nodes.document_node.chunk_split import ChunkSplitNode +from hugegraph_llm.nodes.index_node.build_vector_index import BuildVectorIndexNode from hugegraph_llm.state.ai_state import WkFlowInput import json from PyCGraph import GPipeline -from hugegraph_llm.operators.document_op.chunk_split import ChunkSplitNode -from hugegraph_llm.operators.index_op.build_vector_index import BuildVectorIndexNode from hugegraph_llm.state.ai_state import WkFlowState diff --git a/hugegraph-llm/src/hugegraph_llm/flows/get_graph_index_info.py b/hugegraph-llm/src/hugegraph_llm/flows/get_graph_index_info.py new file mode 100644 index 000000000..fa10d0199 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/flows/get_graph_index_info.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +import os + +from hugegraph_llm.config import huge_settings, llm_settings, resource_path +from hugegraph_llm.flows.common import BaseFlow +from hugegraph_llm.indices.vector_index import VectorIndex +from hugegraph_llm.models.embeddings.init_embedding import model_map +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from hugegraph_llm.nodes.hugegraph_node.fetch_graph_data import FetchGraphDataNode +from PyCGraph import GPipeline +from hugegraph_llm.utils.embedding_utils import ( + get_filename_prefix, + get_index_folder_name, +) + + +class GetGraphIndexInfoFlow(BaseFlow): + def __init__(self): + pass + + def prepare(self, prepared_input: WkFlowInput, *args, **kwargs): + return + + def build_flow(self, *args, **kwargs): + pipeline = GPipeline() + prepared_input = WkFlowInput() + self.prepare(prepared_input, *args, **kwargs) + pipeline.createGParam(prepared_input, "wkflow_input") + pipeline.createGParam(WkFlowState(), "wkflow_state") + fetch_node = FetchGraphDataNode() + pipeline.registerGElement(fetch_node, set(), "fetch_node") + return pipeline + + def post_deal(self, pipeline=None): + graph_summary_info = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) + index_dir = str(os.path.join(resource_path, folder_name, "graph_vids")) + filename_prefix = get_filename_prefix( + llm_settings.embedding_type, + model_map.get(llm_settings.embedding_type, None), + ) + try: + vector_index = VectorIndex.from_index_file(index_dir, filename_prefix) + except FileNotFoundError: + return json.dumps(graph_summary_info, ensure_ascii=False, indent=2) + graph_summary_info["vid_index"] = { + "embed_dim": vector_index.index.d, + "num_vectors": vector_index.index.ntotal, + "num_vids": len(vector_index.properties), + } + return json.dumps(graph_summary_info, ensure_ascii=False, indent=2) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py 
b/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py index f1a6c5f6f..1b0c98253 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py @@ -16,14 +16,10 @@ import json from PyCGraph import GPipeline from hugegraph_llm.flows.common import BaseFlow +from hugegraph_llm.nodes.document_node.chunk_split import ChunkSplitNode +from hugegraph_llm.nodes.hugegraph_node.schema import SchemaNode +from hugegraph_llm.nodes.llm_node.extract_info import ExtractNode from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState -from hugegraph_llm.operators.common_op.check_schema import CheckSchemaNode -from hugegraph_llm.operators.document_op.chunk_split import ChunkSplitNode -from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManagerNode -from hugegraph_llm.operators.llm_op.info_extract import InfoExtractNode -from hugegraph_llm.operators.llm_op.property_graph_extract import ( - PropertyGraphExtractNode, -) from hugegraph_llm.utils.log import log @@ -31,21 +27,6 @@ class GraphExtractFlow(BaseFlow): def __init__(self): pass - def _import_schema( - self, - from_hugegraph=None, - from_extraction=None, - from_user_defined=None, - ): - if from_hugegraph: - return SchemaManagerNode() - elif from_user_defined: - return CheckSchemaNode() - elif from_extraction: - raise NotImplementedError("Not implemented yet") - else: - raise ValueError("No input data / invalid schema type") - def prepare( self, prepared_input: WkFlowInput, schema, texts, example_prompt, extract_type ): @@ -55,17 +36,7 @@ def prepare( prepared_input.split_type = "document" prepared_input.example_prompt = example_prompt prepared_input.schema = schema - schema = schema.strip() - if schema.startswith("{"): - try: - schema = json.loads(schema) - prepared_input.schema = schema - except json.JSONDecodeError as exc: - log.error("Invalid JSON format in schema. 
Please check it again.") - raise ValueError("Invalid JSON format in schema.") from exc - else: - log.info("Get schema '%s' from graphdb.", schema) - prepared_input.graph_name = schema + prepared_input.extract_type = extract_type return def build_flow(self, schema, texts, example_prompt, extract_type): @@ -76,27 +47,10 @@ def build_flow(self, schema, texts, example_prompt, extract_type): pipeline.createGParam(prepared_input, "wkflow_input") pipeline.createGParam(WkFlowState(), "wkflow_state") - schema = schema.strip() - schema_node = None - if schema.startswith("{"): - try: - schema = json.loads(schema) - schema_node = self._import_schema(from_user_defined=schema) - except json.JSONDecodeError as exc: - log.error("Invalid JSON format in schema. Please check it again.") - raise ValueError("Invalid JSON format in schema.") from exc - else: - log.info("Get schema '%s' from graphdb.", schema) - schema_node = self._import_schema(from_hugegraph=schema) + schema_node = SchemaNode() chunk_split_node = ChunkSplitNode() - graph_extract_node = None - if extract_type == "triples": - graph_extract_node = InfoExtractNode() - elif extract_type == "property_graph": - graph_extract_node = PropertyGraphExtractNode() - else: - raise ValueError(f"Unsupported extract_type: {extract_type}") + graph_extract_node = ExtractNode() pipeline.registerGElement(schema_node, set(), "schema_node") pipeline.registerGElement(chunk_split_node, set(), "chunk_split") pipeline.registerGElement( diff --git a/hugegraph-llm/src/hugegraph_llm/flows/import_graph_data.py b/hugegraph-llm/src/hugegraph_llm/flows/import_graph_data.py new file mode 100644 index 000000000..5581ef107 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/flows/import_graph_data.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import gradio as gr +from PyCGraph import GPipeline +from hugegraph_llm.flows.common import BaseFlow +from hugegraph_llm.nodes.hugegraph_node.commit_to_hugegraph import Commit2GraphNode +from hugegraph_llm.nodes.hugegraph_node.schema import SchemaNode +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from hugegraph_llm.utils.log import log + + +class ImportGraphDataFlow(BaseFlow): + def __init__(self): + pass + + def prepare(self, prepared_input: WkFlowInput, data, schema): + try: + data_json = json.loads(data.strip()) if isinstance(data, str) else data + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON for 'data': {e.msg}") from e + log.debug( + "Import graph data (truncated): %s", + (data[:512] + "...") + if isinstance(data, str) and len(data) > 512 + else (data if isinstance(data, str) else ""), + ) + prepared_input.data_json = data_json + prepared_input.schema = schema + return + + def build_flow(self, data, schema): + pipeline = GPipeline() + prepared_input = WkFlowInput() + # prepare input data + self.prepare(prepared_input, data, schema) + + pipeline.createGParam(prepared_input, "wkflow_input") + pipeline.createGParam(WkFlowState(), "wkflow_state") + + schema_node = SchemaNode() + commit_node = Commit2GraphNode() + pipeline.registerGElement(schema_node, set(), "schema_node") + pipeline.registerGElement(commit_node, {schema_node}, 
"commit_node") + + return pipeline + + def post_deal(self, pipeline=None): + res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + gr.Info("Import graph data successfully!") + return json.dumps(res, ensure_ascii=False, indent=2) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/prompt_generate.py b/hugegraph-llm/src/hugegraph_llm/flows/prompt_generate.py new file mode 100644 index 000000000..aece6bd61 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/flows/prompt_generate.py @@ -0,0 +1,63 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from hugegraph_llm.flows.common import BaseFlow +from hugegraph_llm.nodes.llm_node.prompt_generate import PromptGenerateNode +from hugegraph_llm.state.ai_state import WkFlowInput + +from PyCGraph import GPipeline + +from hugegraph_llm.state.ai_state import WkFlowState + + +class PromptGenerateFlow(BaseFlow): + def __init__(self): + pass + + def prepare(self, prepared_input: WkFlowInput, source_text, scenario, example_name): + """ + Prepare input data for PromptGenerate workflow + """ + prepared_input.source_text = source_text + prepared_input.scenario = scenario + prepared_input.example_name = example_name + return + + def build_flow(self, source_text, scenario, example_name): + """ + Build the PromptGenerate workflow + """ + pipeline = GPipeline() + # Prepare workflow input + prepared_input = WkFlowInput() + self.prepare(prepared_input, source_text, scenario, example_name) + + pipeline.createGParam(prepared_input, "wkflow_input") + pipeline.createGParam(WkFlowState(), "wkflow_state") + + # Create PromptGenerate node + prompt_generate_node = PromptGenerateNode() + pipeline.registerGElement(prompt_generate_node, set(), "prompt_generate") + + return pipeline + + def post_deal(self, pipeline=None): + """ + Process the execution result of PromptGenerate workflow + """ + res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + return res.get( + "generated_extract_prompt", "Generation failed. Please check the logs." 
+ ) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py b/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py index b096310db..559540ce3 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py @@ -15,10 +15,15 @@ import threading from typing import Dict, Any -from PyCGraph import GPipelineManager +from PyCGraph import GPipeline, GPipelineManager from hugegraph_llm.flows.build_vector_index import BuildVectorIndexFlow from hugegraph_llm.flows.common import BaseFlow from hugegraph_llm.flows.graph_extract import GraphExtractFlow +from hugegraph_llm.flows.import_graph_data import ImportGraphDataFlow +from hugegraph_llm.flows.update_vid_embeddings import UpdateVidEmbeddingsFlows +from hugegraph_llm.flows.get_graph_index_info import GetGraphIndexInfoFlow +from hugegraph_llm.flows.build_schema import BuildSchemaFlow +from hugegraph_llm.flows.prompt_generate import PromptGenerateFlow from hugegraph_llm.utils.log import log @@ -37,6 +42,26 @@ def __init__(self, max_pipeline: int = 10): "manager": GPipelineManager(), "flow": GraphExtractFlow(), } + self.pipeline_pool["import_graph_data"] = { + "manager": GPipelineManager(), + "flow": ImportGraphDataFlow(), + } + self.pipeline_pool["update_vid_embeddings"] = { + "manager": GPipelineManager(), + "flow": UpdateVidEmbeddingsFlows(), + } + self.pipeline_pool["get_graph_index_info"] = { + "manager": GPipelineManager(), + "flow": GetGraphIndexInfoFlow(), + } + self.pipeline_pool["build_schema"] = { + "manager": GPipelineManager(), + "flow": BuildSchemaFlow(), + } + self.pipeline_pool["prompt_generate"] = { + "manager": GPipelineManager(), + "flow": PromptGenerateFlow(), + } self.max_pipeline = max_pipeline # TODO: Implement Agentic Workflow @@ -46,9 +71,9 @@ def agentic_flow(self): def schedule_flow(self, flow: str, *args, **kwargs): if flow not in self.pipeline_pool: raise ValueError(f"Unsupported workflow {flow}") - manager = 
self.pipeline_pool[flow]["manager"] + manager: GPipelineManager = self.pipeline_pool[flow]["manager"] flow: BaseFlow = self.pipeline_pool[flow]["flow"] - pipeline = manager.fetch() + pipeline: GPipeline = manager.fetch() if pipeline is None: # call coresponding flow_func to create new workflow pipeline = flow.build_flow(*args, **kwargs) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/update_vid_embeddings.py b/hugegraph-llm/src/hugegraph_llm/flows/update_vid_embeddings.py new file mode 100644 index 000000000..b3f0d9923 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/flows/update_vid_embeddings.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from PyCGraph import CStatus, GPipeline +from hugegraph_llm.flows.common import BaseFlow, WkFlowInput +from hugegraph_llm.nodes.hugegraph_node.fetch_graph_data import FetchGraphDataNode +from hugegraph_llm.nodes.index_node.build_semantic_index import BuildSemanticIndexNode +from hugegraph_llm.state.ai_state import WkFlowState + + +class UpdateVidEmbeddingsFlows(BaseFlow): + def prepare(self, prepared_input: WkFlowInput): + return CStatus() + + def build_flow(self): + pipeline = GPipeline() + prepared_input = WkFlowInput() + # prepare input data + self.prepare(prepared_input) + + pipeline.createGParam(prepared_input, "wkflow_input") + pipeline.createGParam(WkFlowState(), "wkflow_state") + + fetch_node = FetchGraphDataNode() + build_node = BuildSemanticIndexNode() + pipeline.registerGElement(fetch_node, set(), "fetch_node") + pipeline.registerGElement(build_node, {fetch_node}, "build_node") + + return pipeline + + def post_deal(self, pipeline): + res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + removed_num = res.get("removed_vid_vector_num", 0) + added_num = res.get("added_vid_vector_num", 0) + return f"Removed {removed_num} vectors, added {added_num} vectors." diff --git a/hugegraph-llm/src/hugegraph_llm/flows/utils.py b/hugegraph-llm/src/hugegraph_llm/flows/utils.py new file mode 100644 index 000000000..b4ba05c84 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/flows/utils.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from hugegraph_llm.state.ai_state import WkFlowInput +from hugegraph_llm.utils.log import log + + +def prepare_schema(prepared_input: WkFlowInput, schema): + schema = schema.strip() + if schema.startswith("{"): + try: + schema = json.loads(schema) + prepared_input.schema = schema + except json.JSONDecodeError as exc: + log.error("Invalid JSON format in schema. Please check it again.") + raise ValueError("Invalid JSON format in schema.") from exc + else: + log.info("Get schema '%s' from graphdb.", schema) + prepared_input.graph_name = schema + return diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py new file mode 100644 index 000000000..0ea0675c0 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from PyCGraph import GNode, CStatus +from hugegraph_llm.nodes.util import init_context +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState + + +class BaseNode(GNode): + context: WkFlowState = None + wk_input: WkFlowInput = None + + def init(self): + return init_context(self) + + def node_init(self): + """ + Node initialization method, can be overridden by subclasses. + Returns a CStatus object indicating whether initialization succeeded. + """ + return CStatus() + + def run(self): + """ + Main logic for node execution, can be overridden by subclasses. + Returns a CStatus object indicating whether execution succeeded. + """ + sts = self.node_init() + if sts.isErr(): + return sts + self.context.lock() + try: + data_json = self.context.to_json() + finally: + self.context.unlock() + + try: + res = self.operator_schedule(data_json) + except Exception as exc: + import traceback + + node_info = f"Node type: {type(self).__name__}, Node object: {self}" + err_msg = f"Node failed: {exc}\n{node_info}\n{traceback.format_exc()}" + return CStatus(-1, err_msg) + + self.context.lock() + try: + if isinstance(res, dict): + self.context.assign_from_json(res) + finally: + self.context.unlock() + return CStatus() + + def operator_schedule(self, data_json): + """ + Interface for scheduling the operator, can be overridden by subclasses. + Returns a CStatus object indicating whether scheduling succeeded. + """ + pass diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/document_node/chunk_split.py b/hugegraph-llm/src/hugegraph_llm/nodes/document_node/chunk_split.py new file mode 100644 index 000000000..4c5acbe97 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/document_node/chunk_split.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hugegraph_llm.nodes.base_node import BaseNode +from PyCGraph import CStatus +from hugegraph_llm.operators.document_op.chunk_split import ChunkSplit +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState + + +class ChunkSplitNode(BaseNode): + chunk_split_op: ChunkSplit + context: WkFlowState = None + wk_input: WkFlowInput = None + + def node_init(self): + if ( + self.wk_input.texts is None + or self.wk_input.language is None + or self.wk_input.split_type is None + ): + return CStatus(-1, "Error occurs when prepare for workflow input") + texts = self.wk_input.texts + language = self.wk_input.language + split_type = self.wk_input.split_type + if isinstance(texts, str): + texts = [texts] + self.chunk_split_op = ChunkSplit(texts, split_type, language) + return CStatus() + + def operator_schedule(self, data_json): + return self.chunk_split_op.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/commit_to_hugegraph.py b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/commit_to_hugegraph.py new file mode 100644 index 000000000..b576e8170 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/commit_to_hugegraph.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license 
agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from PyCGraph import CStatus +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.operators.hugegraph_op.commit_to_hugegraph import Commit2Graph +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState + + +class Commit2GraphNode(BaseNode): + commit_to_graph_op: Commit2Graph + context: WkFlowState = None + wk_input: WkFlowInput = None + + def node_init(self): + data_json = self.wk_input.data_json if self.wk_input.data_json else None + if data_json: + self.context.assign_from_json(data_json) + self.commit_to_graph_op = Commit2Graph() + return CStatus() + + def operator_schedule(self, data_json): + return self.commit_to_graph_op.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/fetch_graph_data.py b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/fetch_graph_data.py new file mode 100644 index 000000000..b2434e524 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/fetch_graph_data.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from PyCGraph import CStatus +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.operators.hugegraph_op.fetch_graph_data import FetchGraphData +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from hugegraph_llm.utils.hugegraph_utils import get_hg_client + + +class FetchGraphDataNode(BaseNode): + fetch_graph_data_op: FetchGraphData + context: WkFlowState = None + wk_input: WkFlowInput = None + + def node_init(self): + self.fetch_graph_data_op = FetchGraphData(get_hg_client()) + return CStatus() + + def operator_schedule(self, data_json): + return self.fetch_graph_data_op.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py new file mode 100644 index 000000000..71c490b20 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from PyCGraph import CStatus +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.operators.common_op.check_schema import CheckSchema +from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from hugegraph_llm.utils.log import log + + +class SchemaNode(BaseNode): + schema_manager: SchemaManager + check_schema: CheckSchema + context: WkFlowState = None + wk_input: WkFlowInput = None + + schema = None + + def _import_schema( + self, + from_hugegraph=None, + from_extraction=None, + from_user_defined=None, + ): + if from_hugegraph: + return SchemaManager(from_hugegraph) + elif from_user_defined: + return CheckSchema(from_user_defined) + elif from_extraction: + raise NotImplementedError("Not implemented yet") + else: + raise ValueError("No input data / invalid schema type") + + def node_init(self): + self.schema = self.wk_input.schema + self.schema = self.schema.strip() + if self.schema.startswith("{"): + try: + schema = json.loads(self.schema) + self.check_schema = self._import_schema(from_user_defined=schema) + except json.JSONDecodeError as exc: + log.error("Invalid JSON format in schema. 
Please check it again.") + raise ValueError("Invalid JSON format in schema.") from exc + else: + log.info("Get schema '%s' from graphdb.", self.schema) + self.schema_manager = self._import_schema(from_hugegraph=self.schema) + return CStatus() + + def operator_schedule(self, data_json): + print(f"check data json {data_json}") + if self.schema.startswith("{"): + try: + return self.check_schema.run(data_json) + except json.JSONDecodeError as exc: + log.error("Invalid JSON format in schema. Please check it again.") + raise ValueError("Invalid JSON format in schema.") from exc + else: + log.info("Get schema '%s' from graphdb.", self.schema) + return self.schema_manager.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_semantic_index.py b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_semantic_index.py new file mode 100644 index 000000000..ab31fa394 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_semantic_index.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from PyCGraph import CStatus +from hugegraph_llm.config import llm_settings +from hugegraph_llm.models.embeddings.init_embedding import get_embedding +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.operators.index_op.build_semantic_index import BuildSemanticIndex +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState + + +class BuildSemanticIndexNode(BaseNode): + build_semantic_index_op: BuildSemanticIndex + context: WkFlowState = None + wk_input: WkFlowInput = None + + def node_init(self): + self.build_semantic_index_op = BuildSemanticIndex(get_embedding(llm_settings)) + return CStatus() + + def operator_schedule(self, data_json): + return self.build_semantic_index_op.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_vector_index.py b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_vector_index.py new file mode 100644 index 000000000..cf2f9b677 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_vector_index.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from PyCGraph import CStatus +from hugegraph_llm.config import llm_settings +from hugegraph_llm.models.embeddings.init_embedding import get_embedding +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.operators.index_op.build_vector_index import BuildVectorIndex +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState + + +class BuildVectorIndexNode(BaseNode): + build_vector_index_op: BuildVectorIndex + context: WkFlowState = None + wk_input: WkFlowInput = None + + def node_init(self): + self.build_vector_index_op = BuildVectorIndex(get_embedding(llm_settings)) + return CStatus() + + def operator_schedule(self, data_json): + return self.build_vector_index_op.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py new file mode 100644 index 000000000..8bceed804 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from PyCGraph import CStatus +from hugegraph_llm.config import llm_settings +from hugegraph_llm.models.llms.init_llm import get_chat_llm +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.operators.llm_op.info_extract import InfoExtract +from hugegraph_llm.operators.llm_op.property_graph_extract import PropertyGraphExtract +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState + + +class ExtractNode(BaseNode): + property_graph_extract: PropertyGraphExtract + info_extract: InfoExtract + context: WkFlowState = None + wk_input: WkFlowInput = None + + extract_type: str = None + + def node_init(self): + llm = get_chat_llm(llm_settings) + if self.wk_input.example_prompt is None: + return CStatus(-1, "Error occurs when prepare for workflow input") + example_prompt = self.wk_input.example_prompt + extract_type = self.wk_input.extract_type + self.extract_type = extract_type + if extract_type == "triples": + self.info_extract = InfoExtract(llm, example_prompt) + elif extract_type == "property_graph": + self.property_graph_extract = PropertyGraphExtract(llm, example_prompt) + else: + return CStatus(-1, f"Unsupported extract_type: {extract_type}") + return CStatus() + + def operator_schedule(self, data_json): + if self.extract_type == "triples": + return self.info_extract.run(data_json) + elif self.extract_type == "property_graph": + return self.property_graph_extract.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/prompt_generate.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/prompt_generate.py new file mode 100644 index 000000000..317f9e6ac --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/prompt_generate.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from PyCGraph import CStatus +from hugegraph_llm.config import llm_settings +from hugegraph_llm.models.llms.init_llm import get_chat_llm +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.operators.llm_op.prompt_generate import PromptGenerate +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState + + +class PromptGenerateNode(BaseNode): + prompt_generate: PromptGenerate + context: WkFlowState = None + wk_input: WkFlowInput = None + + def node_init(self): + """ + Node initialization method, initialize PromptGenerate operator + """ + llm = get_chat_llm(llm_settings) + if not all( + [ + self.wk_input.source_text, + self.wk_input.scenario, + self.wk_input.example_name, + ] + ): + return CStatus( + -1, + "Missing required parameters: source_text, scenario, or example_name", + ) + + self.prompt_generate = PromptGenerate(llm) + context = { + "source_text": self.wk_input.source_text, + "scenario": self.wk_input.scenario, + "example_name": self.wk_input.example_name, + } + self.context.assign_from_json(context) + return CStatus() + + def operator_schedule(self, data_json): + """ + Schedule the execution of PromptGenerate operator + """ + return self.prompt_generate.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py new file mode 100644 index 000000000..a28b41346 --- 
/dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from PyCGraph import CStatus +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from hugegraph_llm.models.llms.init_llm import get_chat_llm +from hugegraph_llm.config import llm_settings +from hugegraph_llm.operators.llm_op.schema_build import SchemaBuilder +from hugegraph_llm.utils.log import log + + +class SchemaBuildNode(BaseNode): + schema_builder: SchemaBuilder + context: WkFlowState = None + wk_input: WkFlowInput = None + + def node_init(self): + llm = get_chat_llm(llm_settings) + self.schema_builder = SchemaBuilder(llm) + + # texts -> raw_texts + raw_texts = [] + if self.wk_input.texts: + if isinstance(self.wk_input.texts, list): + raw_texts = [t for t in self.wk_input.texts if isinstance(t, str)] + elif isinstance(self.wk_input.texts, str): + raw_texts = [self.wk_input.texts] + + # query_examples: already parsed list[dict] or raw JSON string + query_examples = [] + qe_src = self.wk_input.query_examples if self.wk_input.query_examples else None + if qe_src: + try: + parsed_examples = json.loads(qe_src) + 
# Validate and retain the description and gremlin fields + query_examples = [ + { + "description": ex.get("description", ""), + "gremlin": ex.get("gremlin", ""), + } + for ex in parsed_examples + if isinstance(ex, dict) and "description" in ex and "gremlin" in ex + ] + except json.JSONDecodeError as e: + return CStatus(-1, f"Query Examples is not in a valid JSON format: {e}") + + # few_shot_schema: already parsed dict or raw JSON string + few_shot_schema = {} + fss_src = ( + self.wk_input.few_shot_schema if self.wk_input.few_shot_schema else None + ) + if fss_src: + try: + few_shot_schema = json.loads(fss_src) + except json.JSONDecodeError as e: + return CStatus( + -1, f"Few Shot Schema is not in a valid JSON format: {e}" + ) + + _context_payload = { + "raw_texts": raw_texts, + "query_examples": query_examples, + "few_shot_schema": few_shot_schema, + } + self.context.assign_from_json(_context_payload) + + return CStatus() + + def operator_schedule(self, data_json): + try: + schema_result = self.schema_builder.run(data_json) + + return {"schema": schema_result} + except Exception as e: + log.error("Failed to generate schema: %s", e) + return {"schema": f"Schema generation failed: {e}"} diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/util.py b/hugegraph-llm/src/hugegraph_llm/nodes/util.py new file mode 100644 index 000000000..60bdc2e86 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/util.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from PyCGraph import CStatus + + +def init_context(obj) -> CStatus: + try: + obj.context = obj.getGParamWithNoEmpty("wkflow_state") + obj.wk_input = obj.getGParamWithNoEmpty("wkflow_input") + if obj.context is None or obj.wk_input is None: + return CStatus(-1, "Required workflow parameters not found") + return CStatus() + except Exception as e: + return CStatus(-1, f"Failed to initialize context: {str(e)}") diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py index 7a533517a..c1c742032 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py @@ -20,12 +20,8 @@ from hugegraph_llm.enums.property_cardinality import PropertyCardinality from hugegraph_llm.enums.property_data_type import PropertyDataType -from hugegraph_llm.operators.util import init_context from hugegraph_llm.utils.log import log -from PyCGraph import GNode, CStatus -from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState - def log_and_raise(message: str) -> None: log.warning(message) @@ -174,159 +170,3 @@ def _add_missing_properties( } ) property_label_set.add(prop) - - -class CheckSchemaNode(GNode): - context: WkFlowState = None - wk_input: WkFlowInput = None - - def init(self): - return init_context(self) - - def node_init(self): - if self.wk_input.schema is None: - return CStatus(-1, "Error occurs when prepare for workflow input") - self.data = self.wk_input.schema - return CStatus() - - 
def run(self) -> CStatus: - # init workflow input - sts = self.node_init() - if sts.isErr(): - return sts - # 1. Validate the schema structure - self.context.lock() - schema = self.data or self.context.schema - self._validate_schema(schema) - # 2. Process property labels and also create a set for it - property_labels, property_label_set = self._process_property_labels(schema) - # 3. Process properties in given vertex/edge labels - self._process_vertex_labels(schema, property_labels, property_label_set) - self._process_edge_labels(schema, property_labels, property_label_set) - # 4. Update schema with processed pks - schema["propertykeys"] = property_labels - self.context.schema = schema - self.context.unlock() - return CStatus() - - def _validate_schema(self, schema: Dict[str, Any]) -> None: - check_type(schema, dict, "Input data is not a dictionary.") - if "vertexlabels" not in schema or "edgelabels" not in schema: - log_and_raise("Input data does not contain 'vertexlabels' or 'edgelabels'.") - check_type( - schema["vertexlabels"], list, "'vertexlabels' in input data is not a list." - ) - check_type( - schema["edgelabels"], list, "'edgelabels' in input data is not a list." 
- ) - - def _process_property_labels(self, schema: Dict[str, Any]) -> (list, set): - property_labels = schema.get("propertykeys", []) - check_type( - property_labels, - list, - "'propertykeys' in input data is not of correct type.", - ) - property_label_set = {label["name"] for label in property_labels} - return property_labels, property_label_set - - def _process_vertex_labels( - self, schema: Dict[str, Any], property_labels: list, property_label_set: set - ) -> None: - for vertex_label in schema["vertexlabels"]: - self._validate_vertex_label(vertex_label) - properties = vertex_label["properties"] - primary_keys = self._process_keys( - vertex_label, "primary_keys", properties[:1] - ) - if len(primary_keys) == 0: - log_and_raise(f"'primary_keys' of {vertex_label['name']} is empty.") - vertex_label["primary_keys"] = primary_keys - nullable_keys = self._process_keys( - vertex_label, "nullable_keys", properties[1:] - ) - vertex_label["nullable_keys"] = nullable_keys - self._add_missing_properties( - properties, property_labels, property_label_set - ) - - def _process_edge_labels( - self, schema: Dict[str, Any], property_labels: list, property_label_set: set - ) -> None: - for edge_label in schema["edgelabels"]: - self._validate_edge_label(edge_label) - properties = edge_label.get("properties", []) - self._add_missing_properties( - properties, property_labels, property_label_set - ) - - def _validate_vertex_label(self, vertex_label: Dict[str, Any]) -> None: - check_type(vertex_label, dict, "VertexLabel in input data is not a dictionary.") - if "name" not in vertex_label: - log_and_raise("VertexLabel in input data does not contain 'name'.") - check_type( - vertex_label["name"], str, "'name' in vertex_label is not of correct type." 
- ) - if "properties" not in vertex_label: - log_and_raise("VertexLabel in input data does not contain 'properties'.") - check_type( - vertex_label["properties"], - list, - "'properties' in vertex_label is not of correct type.", - ) - if len(vertex_label["properties"]) == 0: - log_and_raise("'properties' in vertex_label is empty.") - - def _validate_edge_label(self, edge_label: Dict[str, Any]) -> None: - check_type(edge_label, dict, "EdgeLabel in input data is not a dictionary.") - if ( - "name" not in edge_label - or "source_label" not in edge_label - or "target_label" not in edge_label - ): - log_and_raise( - "EdgeLabel in input data does not contain 'name', 'source_label', 'target_label'." - ) - check_type( - edge_label["name"], str, "'name' in edge_label is not of correct type." - ) - check_type( - edge_label["source_label"], - str, - "'source_label' in edge_label is not of correct type.", - ) - check_type( - edge_label["target_label"], - str, - "'target_label' in edge_label is not of correct type.", - ) - - def _process_keys( - self, label: Dict[str, Any], key_type: str, default_keys: list - ) -> list: - keys = label.get(key_type, default_keys) - check_type( - keys, list, f"'{key_type}' in {label['name']} is not of correct type." 
- ) - new_keys = [key for key in keys if key in label["properties"]] - return new_keys - - def _add_missing_properties( - self, properties: list, property_labels: list, property_label_set: set - ) -> None: - for prop in properties: - if prop not in property_label_set: - property_labels.append( - { - "name": prop, - "data_type": PropertyDataType.DEFAULT.value, - "cardinality": PropertyCardinality.DEFAULT.value, - } - ) - property_label_set.add(prop) - - def get_result(self): - self.context.lock() - res = self.context.to_json() - self.context.unlock() - return res diff --git a/hugegraph-llm/src/hugegraph_llm/operators/document_op/chunk_split.py b/hugegraph-llm/src/hugegraph_llm/operators/document_op/chunk_split.py index d779a40ab..c31e77af7 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/document_op/chunk_split.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/document_op/chunk_split.py @@ -19,8 +19,6 @@ from typing import Literal, Dict, Any, Optional, Union, List from langchain_text_splitters import RecursiveCharacterTextSplitter -from hugegraph_llm.operators.util import init_context -from PyCGraph import GNode, CStatus # Constants LANGUAGE_ZH = "zh" @@ -30,62 +28,6 @@ SPLIT_TYPE_SENTENCE = "sentence" -class ChunkSplitNode(GNode): - def init(self): - return init_context(self) - - def node_init(self): - if ( - self.wk_input.texts is None - or self.wk_input.language is None - or self.wk_input.split_type is None - ): - return CStatus(-1, "Error occurs when prepare for workflow input") - texts = self.wk_input.texts - language = self.wk_input.language - split_type = self.wk_input.split_type - if isinstance(texts, str): - texts = [texts] - self.texts = texts - self.separators = self._get_separators(language) - self.text_splitter = self._get_text_splitter(split_type) - return CStatus() - - def _get_separators(self, language: str) -> List[str]: - if language == LANGUAGE_ZH: - return ["\n\n", "\n", "。", ",", ""] - if language == LANGUAGE_EN: - return ["\n\n", "\n", 
".", ",", " ", ""] - raise ValueError("language must be zh or en") - - def _get_text_splitter(self, split_type: str): - if split_type == SPLIT_TYPE_DOCUMENT: - return lambda text: [text] - if split_type == SPLIT_TYPE_PARAGRAPH: - return RecursiveCharacterTextSplitter( - chunk_size=500, chunk_overlap=30, separators=self.separators - ).split_text - if split_type == SPLIT_TYPE_SENTENCE: - return RecursiveCharacterTextSplitter( - chunk_size=50, chunk_overlap=0, separators=self.separators - ).split_text - raise ValueError("Type must be document, paragraph or sentence") - - def run(self): - sts = self.node_init() - if sts.isErr(): - return sts - all_chunks = [] - for text in self.texts: - chunks = self.text_splitter(text) - all_chunks.extend(chunks) - - self.context.lock() - self.context.chunks = all_chunks - self.context.unlock() - return CStatus() - - class ChunkSplit: def __init__( self, diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py index 5cc846d21..9eec04f7f 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py @@ -40,15 +40,19 @@ def run(self, data: dict) -> Dict[str, Any]: schema = data.get("schema") vertices = data.get("vertices", []) edges = data.get("edges", []) - + print(f"get schema {schema}") if not vertices and not edges: - log.critical("(Loading) Both vertices and edges are empty. Please check the input data again.") + log.critical( + "(Loading) Both vertices and edges are empty. Please check the input data again." 
+ ) raise ValueError("Both vertices and edges input are empty.") if not schema: # TODO: ensure the function works correctly (update the logic later) self.schema_free_mode(data.get("triples", [])) - log.warning("Using schema_free mode, could try schema_define mode for better effect!") + log.warning( + "Using schema_free mode, could try schema_define mode for better effect!" + ) else: self.init_schema_if_need(schema) self.load_into_graph(vertices, edges, schema) @@ -64,7 +68,9 @@ def _set_default_property(self, key, input_properties, property_label_map): # list or set default_value = [] input_properties[key] = default_value - log.warning("Property '%s' missing in vertex, set to '%s' for now", key, default_value) + log.warning( + "Property '%s' missing in vertex, set to '%s' for now", key, default_value + ) def _handle_graph_creation(self, func, *args, **kwargs): try: @@ -78,29 +84,42 @@ def _handle_graph_creation(self, func, *args, **kwargs): def load_into_graph(self, vertices, edges, schema): # pylint: disable=too-many-statements # pylint: disable=R0912 (too-many-branches) - vertex_label_map = {v_label["name"]: v_label for v_label in schema["vertexlabels"]} + vertex_label_map = { + v_label["name"]: v_label for v_label in schema["vertexlabels"] + } edge_label_map = {e_label["name"]: e_label for e_label in schema["edgelabels"]} - property_label_map = {p_label["name"]: p_label for p_label in schema["propertykeys"]} + property_label_map = { + p_label["name"]: p_label for p_label in schema["propertykeys"] + } for vertex in vertices: input_label = vertex["label"] # 1. 
ensure the input_label in the graph schema if input_label not in vertex_label_map: - log.critical("(Input) VertexLabel %s not found in schema, skip & need check it!", input_label) + log.critical( + "(Input) VertexLabel %s not found in schema, skip & need check it!", + input_label, + ) continue input_properties = vertex["properties"] vertex_label = vertex_label_map[input_label] primary_keys = vertex_label["primary_keys"] nullable_keys = vertex_label.get("nullable_keys", []) - non_null_keys = [key for key in vertex_label["properties"] if key not in nullable_keys] + non_null_keys = [ + key for key in vertex_label["properties"] if key not in nullable_keys + ] has_problem = False # 2. Handle primary-keys mode vertex for pk in primary_keys: if not input_properties.get(pk): if len(primary_keys) == 1: - log.error("Primary-key '%s' missing in vertex %s, skip it & need check it again", pk, vertex) + log.error( + "Primary-key '%s' missing in vertex %s, skip it & need check it again", + pk, + vertex, + ) has_problem = True break # TODO: transform to Enum first (better in earlier step) @@ -110,14 +129,20 @@ def load_into_graph(self, vertices, edges, schema): # pylint: disable=too-many- input_properties[pk] = default_value_map(data_type) else: input_properties[pk] = [] - log.warning("Primary-key '%s' missing in vertex %s, mark empty & need check it again!", pk, vertex) + log.warning( + "Primary-key '%s' missing in vertex %s, mark empty & need check it again!", + pk, + vertex, + ) if has_problem: continue # 3. Ensure all non-nullable props are set for key in non_null_keys: if key not in input_properties: - self._set_default_property(key, input_properties, property_label_map) + self._set_default_property( + key, input_properties, property_label_map + ) # 4. 
Check all data type value is right for key, value in input_properties.items(): @@ -125,14 +150,19 @@ def load_into_graph(self, vertices, edges, schema): # pylint: disable=too-many- data_type = property_label_map[key]["data_type"] cardinality = property_label_map[key]["cardinality"] if not self._check_property_data_type(data_type, cardinality, value): - log.error("Property type/format '%s' is not correct, skip it & need check it again", key) + log.error( + "Property type/format '%s' is not correct, skip it & need check it again", + key, + ) has_problem = True break if has_problem: continue # TODO: we could try batch add vertices first, setback to single-mode if failed - vid = self._handle_graph_creation(self.client.graph().addVertex, input_label, input_properties).id + vid = self._handle_graph_creation( + self.client.graph().addVertex, input_label, input_properties + ).id vertex["id"] = vid for edge in edges: @@ -142,11 +172,16 @@ def load_into_graph(self, vertices, edges, schema): # pylint: disable=too-many- properties = edge["properties"] if label not in edge_label_map: - log.critical("(Input) EdgeLabel %s not found in schema, skip & need check it!", label) + log.critical( + "(Input) EdgeLabel %s not found in schema, skip & need check it!", + label, + ) continue # TODO: we could try batch add edges first, setback to single-mode if failed - self._handle_graph_creation(self.client.graph().addEdge, label, start, end, properties) + self._handle_graph_creation( + self.client.graph().addEdge, label, start, end, properties + ) def init_schema_if_need(self, schema: dict): properties = schema["propertykeys"] @@ -170,19 +205,27 @@ def init_schema_if_need(self, schema: dict): source_vertex_label = edge["source_label"] target_vertex_label = edge["target_label"] properties = edge["properties"] - self.schema.edgeLabel(edge_label).sourceLabel(source_vertex_label).targetLabel( - target_vertex_label - ).properties(*properties).nullableKeys(*properties).ifNotExist().create() + 
self.schema.edgeLabel(edge_label).sourceLabel( + source_vertex_label + ).targetLabel(target_vertex_label).properties(*properties).nullableKeys( + *properties + ).ifNotExist().create() def schema_free_mode(self, data): self.schema.propertyKey("name").asText().ifNotExist().create() - self.schema.vertexLabel("vertex").useCustomizeStringId().properties("name").ifNotExist().create() - self.schema.edgeLabel("edge").sourceLabel("vertex").targetLabel("vertex").properties( + self.schema.vertexLabel("vertex").useCustomizeStringId().properties( "name" ).ifNotExist().create() + self.schema.edgeLabel("edge").sourceLabel("vertex").targetLabel( + "vertex" + ).properties("name").ifNotExist().create() - self.schema.indexLabel("vertexByName").onV("vertex").by("name").secondary().ifNotExist().create() - self.schema.indexLabel("edgeByName").onE("edge").by("name").secondary().ifNotExist().create() + self.schema.indexLabel("vertexByName").onV("vertex").by( + "name" + ).secondary().ifNotExist().create() + self.schema.indexLabel("edgeByName").onE("edge").by( + "name" + ).secondary().ifNotExist().create() for item in data: s, p, o = (element.strip() for element in item) @@ -196,8 +239,12 @@ def _create_property(self, prop: dict): data_type = PropertyDataType(prop["data_type"]) cardinality = PropertyCardinality(prop["cardinality"]) except ValueError: - log.critical("Invalid data type %s / cardinality %s for property %s, skip & should check it again", - prop["data_type"], prop["cardinality"], name) + log.critical( + "Invalid data type %s / cardinality %s for property %s, skip & should check it again", + prop["data_type"], + prop["cardinality"], + name, + ) return property_key = self.schema.propertyKey(name) @@ -231,7 +278,9 @@ def _set_property_data_type(self, property_key, data_type): log.warning("UUID type is not supported, use text instead") property_key.asText() else: - log.error("Unknown data type %s for property_key %s", data_type, property_key) + log.error( + "Unknown data type %s for 
property_key %s", data_type, property_key + ) def _set_property_cardinality(self, property_key, cardinality): if cardinality == PropertyCardinality.SINGLE: @@ -241,10 +290,17 @@ def _set_property_cardinality(self, property_key, cardinality): elif cardinality == PropertyCardinality.SET: property_key.valueSet() else: - log.error("Unknown cardinality %s for property_key %s", cardinality, property_key) - - def _check_property_data_type(self, data_type: str, cardinality: str, value) -> bool: - if cardinality in (PropertyCardinality.LIST.value, PropertyCardinality.SET.value): + log.error( + "Unknown cardinality %s for property_key %s", cardinality, property_key + ) + + def _check_property_data_type( + self, data_type: str, cardinality: str, value + ) -> bool: + if cardinality in ( + PropertyCardinality.LIST.value, + PropertyCardinality.SET.value, + ): return self._check_collection_data_type(data_type, value) return self._check_single_data_type(data_type, value) @@ -259,14 +315,21 @@ def _check_collection_data_type(self, data_type: str, value) -> bool: def _check_single_data_type(self, data_type: str, value) -> bool: if data_type == PropertyDataType.BOOLEAN.value: return isinstance(value, bool) - if data_type in (PropertyDataType.BYTE.value, PropertyDataType.INT.value, PropertyDataType.LONG.value): + if data_type in ( + PropertyDataType.BYTE.value, + PropertyDataType.INT.value, + PropertyDataType.LONG.value, + ): return isinstance(value, int) if data_type in (PropertyDataType.FLOAT.value, PropertyDataType.DOUBLE.value): return isinstance(value, float) if data_type in (PropertyDataType.TEXT.value, PropertyDataType.UUID.value): return isinstance(value, str) # TODO: check ok below - if data_type == PropertyDataType.DATE.value: # the format should be "yyyy-MM-dd" + if ( + data_type == PropertyDataType.DATE.value + ): # the format should be "yyyy-MM-dd" import re - return isinstance(value, str) and re.match(r'^\d{4}-\d{2}-\d{2}$', value) + + return isinstance(value, str) and 
re.match(r"^\d{4}-\d{2}-\d{2}$", value) raise ValueError(f"Unknown/Unsupported data type: {data_type}") diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py index 670c18b4a..c4e2124c3 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py @@ -17,12 +17,8 @@ from typing import Dict, Any, Optional from hugegraph_llm.config import huge_settings -from hugegraph_llm.operators.util import init_context -from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState from pyhugegraph.client import PyHugeClient -from PyCGraph import GNode, CStatus - class SchemaManager: def __init__(self, graph_name: str): @@ -74,74 +70,3 @@ def run(self, context: Optional[Dict[str, Any]]) -> Dict[str, Any]: # TODO: enhance the logic here context["simple_schema"] = self.simple_schema(schema) return context - - -class SchemaManagerNode(GNode): - context: WkFlowState = None - wk_input: WkFlowInput = None - - def init(self): - return init_context(self) - - def node_init(self): - if self.wk_input.graph_name is None: - return CStatus(-1, "Error occurs when prepare for workflow input") - graph_name = self.wk_input.graph_name - self.graph_name = graph_name - self.client = PyHugeClient( - url=huge_settings.graph_url, - graph=self.graph_name, - user=huge_settings.graph_user, - pwd=huge_settings.graph_pwd, - graphspace=huge_settings.graph_space, - ) - self.schema = self.client.schema() - return CStatus() - - def simple_schema(self, schema: Dict[str, Any]) -> Dict[str, Any]: - mini_schema = {} - - # Add necessary vertexlabels items (3) - if "vertexlabels" in schema: - mini_schema["vertexlabels"] = [] - for vertex in schema["vertexlabels"]: - new_vertex = { - key: vertex[key] - for key in ["id", "name", "properties"] - if key in vertex - } - 
mini_schema["vertexlabels"].append(new_vertex) - - # Add necessary edgelabels items (4) - if "edgelabels" in schema: - mini_schema["edgelabels"] = [] - for edge in schema["edgelabels"]: - new_edge = { - key: edge[key] - for key in ["name", "source_label", "target_label", "properties"] - if key in edge - } - mini_schema["edgelabels"].append(new_edge) - - return mini_schema - - def run(self) -> CStatus: - sts = self.node_init() - if sts.isErr(): - return sts - schema = self.schema.getSchema() - if not schema["vertexlabels"] and not schema["edgelabels"]: - raise Exception(f"Can not get {self.graph_name}'s schema from HugeGraph!") - - self.context.lock() - self.context.schema = schema - # TODO: enhance the logic here - self.context.simple_schema = self.simple_schema(schema) - self.context.unlock() - return CStatus() - - def get_result(self): - self.context.lock() - res = self.context.to_json() - self.context.unlock() - return res diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py index ee89d330f..5cdad0316 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py @@ -30,54 +30,6 @@ ) from hugegraph_llm.utils.log import log -from hugegraph_llm.operators.util import init_context -from hugegraph_llm.models.embeddings.init_embedding import get_embedding -from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState -from PyCGraph import GNode, CStatus - - -class BuildVectorIndexNode(GNode): - context: WkFlowState = None - wk_input: WkFlowInput = None - - def init(self): - return init_context(self) - - def node_init(self): - self.embedding = get_embedding(llm_settings) - self.folder_name = get_index_folder_name( - huge_settings.graph_name, huge_settings.graph_space - ) - self.index_dir = str(os.path.join(resource_path, self.folder_name, "chunks")) - 
self.filename_prefix = get_filename_prefix( - llm_settings.embedding_type, getattr(self.embedding, "model_name", None) - ) - self.vector_index = VectorIndex.from_index_file( - self.index_dir, self.filename_prefix - ) - return CStatus() - - def run(self): - # init workflow input - sts = self.node_init() - if sts.isErr(): - return sts - self.context.lock() - try: - if self.context.chunks is None: - raise ValueError("chunks not found in context.") - chunks = self.context.chunks - finally: - self.context.unlock() - chunks_embedding = [] - log.debug("Building vector index for %s chunks...", len(chunks)) - # TODO: use async_get_texts_embedding instead of single sync method - chunks_embedding = asyncio.run(get_embeddings_parallel(self.embedding, chunks)) - if len(chunks_embedding) > 0: - self.vector_index.add(chunks_embedding, chunks) - self.vector_index.to_index_file(self.index_dir, self.filename_prefix) - return CStatus() - class BuildVectorIndex: def __init__(self, embedding: BaseEmbedding): diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py index 15a8fdda7..571ffde51 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py @@ -18,16 +18,10 @@ import re from typing import List, Any, Dict, Optional -from hugegraph_llm.config import llm_settings from hugegraph_llm.document.chunk_split import ChunkSplitter from hugegraph_llm.models.llms.base import BaseLLM from hugegraph_llm.utils.log import log -from hugegraph_llm.operators.util import init_context -from hugegraph_llm.models.llms.init_llm import get_chat_llm -from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState -from PyCGraph import GNode, CStatus - SCHEMA_EXAMPLE_PROMPT = """## Main Task Extract Triples from the given text and graph schema @@ -213,143 +207,3 @@ def _filter_long_id(self, graph) -> Dict[str, List[Any]]: if 
self.valid(edge["start"]) and self.valid(edge["end"]) ] return graph - - -class InfoExtractNode(GNode): - context: WkFlowState = None - wk_input: WkFlowInput = None - - def init(self): - return init_context(self) - - def node_init(self): - self.llm = get_chat_llm(llm_settings) - if self.wk_input.example_prompt is None: - return CStatus(-1, "Error occurs when prepare for workflow input") - self.example_prompt = self.wk_input.example_prompt - return CStatus() - - def extract_triples_by_regex_with_schema(self, schema, text): - text = text.replace("\\n", " ").replace("\\", " ").replace("\n", " ") - pattern = r"\((.*?), (.*?), (.*?)\) - ([^ ]*)" - matches = re.findall(pattern, text) - - vertices_dict = {v["id"]: v for v in self.context.vertices} - for match in matches: - s, p, o, label = [item.strip() for item in match] - if None in [label, s, p, o]: - continue - # TODO: use a more efficient way to compare the extract & input property - p_lower = p.lower() - for vertex in schema["vertices"]: - if vertex["vertex_label"] == label and any( - pp.lower() == p_lower for pp in vertex["properties"] - ): - id = f"{label}-{s}" - if id not in vertices_dict: - vertices_dict[id] = { - "id": id, - "name": s, - "label": label, - "properties": {p: o}, - } - else: - vertices_dict[id]["properties"].update({p: o}) - break - for edge in schema["edges"]: - if edge["edge_label"] == label: - source_label = edge["source_vertex_label"] - source_id = f"{source_label}-{s}" - if source_id not in vertices_dict: - vertices_dict[source_id] = { - "id": source_id, - "name": s, - "label": source_label, - "properties": {}, - } - target_label = edge["target_vertex_label"] - target_id = f"{target_label}-{o}" - if target_id not in vertices_dict: - vertices_dict[target_id] = { - "id": target_id, - "name": o, - "label": target_label, - "properties": {}, - } - self.context.edges.append( - { - "start": source_id, - "end": target_id, - "type": label, - "properties": {}, - } - ) - break - self.context.vertices = 
list(vertices_dict.values()) - - def extract_triples_by_regex(self, text): - text = text.replace("\\n", " ").replace("\\", " ").replace("\n", " ") - pattern = r"\((.*?), (.*?), (.*?)\)" - self.context.triples += re.findall(pattern, text) - - def run(self) -> CStatus: - sts = self.node_init() - if sts.isErr(): - return sts - self.context.lock() - if self.context.chunks is None: - self.context.unlock() - raise ValueError("parameter required by extract node not found in context.") - schema = self.context.schema - chunks = self.context.chunks - - if schema: - self.context.vertices = [] - self.context.edges = [] - else: - self.context.triples = [] - - self.context.unlock() - - for sentence in chunks: - proceeded_chunk = self.extract_triples_by_llm(schema, sentence) - log.debug( - "[Legacy] %s input: %s \n output:%s", - self.__class__.__name__, - sentence, - proceeded_chunk, - ) - if schema: - self.extract_triples_by_regex_with_schema(schema, proceeded_chunk) - else: - self.extract_triples_by_regex(proceeded_chunk) - - if self.context.call_count: - self.context.call_count += len(chunks) - else: - self.context.call_count = len(chunks) - self._filter_long_id() - return CStatus() - - def extract_triples_by_llm(self, schema, chunk) -> str: - prompt = generate_extract_triple_prompt(chunk, schema) - if self.example_prompt is not None: - prompt = self.example_prompt + prompt - return self.llm.generate(prompt=prompt) - - # TODO: make 'max_length' be a configurable param in settings.py/settings.cfg - def valid(self, element_id: str, max_length: int = 256) -> bool: - if len(element_id.encode("utf-8")) >= max_length: - log.warning("Filter out GraphElementID too long: %s", element_id) - return False - return True - - def _filter_long_id(self): - self.context.vertices = [ - vertex for vertex in self.context.vertices if self.valid(vertex["id"]) - ] - self.context.edges = [ - edge - for edge in self.context.edges - if self.valid(edge["start"]) and self.valid(edge["end"]) - ] diff --git 
a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py index 6e492b8f5..79fb33b4f 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py @@ -21,16 +21,11 @@ import re from typing import List, Any, Dict -from hugegraph_llm.config import llm_settings, prompt +from hugegraph_llm.config import prompt from hugegraph_llm.document.chunk_split import ChunkSplitter from hugegraph_llm.models.llms.base import BaseLLM from hugegraph_llm.utils.log import log -from hugegraph_llm.operators.util import init_context -from hugegraph_llm.models.llms.init_llm import get_chat_llm -from hugegraph_llm.state.ai_state import WkFlowState, WkFlowInput -from PyCGraph import GNode, CStatus - # TODO: It is not clear whether there is any other dependence on the SCHEMA_EXAMPLE_PROMPT variable. # Because the SCHEMA_EXAMPLE_PROMPT variable will no longer change based on # prompt.extract_graph_prompt changes after the system loads, this does not seem to meet expectations. @@ -182,123 +177,3 @@ def process_items(item_list, valid_labels, item_type): "Invalid property graph JSON! 
Please check the extracted JSON data carefully" ) return items - - -class PropertyGraphExtractNode(GNode): - context: WkFlowState = None - wk_input: WkFlowInput = None - - def init(self): - self.NECESSARY_ITEM_KEYS = {"label", "type", "properties"} # pylint: disable=invalid-name - return init_context(self) - - def node_init(self): - self.llm = get_chat_llm(llm_settings) - if self.wk_input.example_prompt is None: - return CStatus(-1, "Error occurs when prepare for workflow input") - self.example_prompt = self.wk_input.example_prompt - return CStatus() - - def run(self) -> CStatus: - sts = self.node_init() - if sts.isErr(): - return sts - self.context.lock() - try: - if self.context.schema is None or self.context.chunks is None: - raise ValueError( - "parameter required by extract node not found in context." - ) - schema = self.context.schema - chunks = self.context.chunks - if self.context.vertices is None: - self.context.vertices = [] - if self.context.edges is None: - self.context.edges = [] - finally: - self.context.unlock() - - items = [] - for chunk in chunks: - proceeded_chunk = self.extract_property_graph_by_llm(schema, chunk) - log.debug( - "[LLM] %s input: %s \n output:%s", - self.__class__.__name__, - chunk, - proceeded_chunk, - ) - items.extend(self._extract_and_filter_label(schema, proceeded_chunk)) - items = filter_item(schema, items) - self.context.lock() - try: - for item in items: - if item["type"] == "vertex": - self.context.vertices.append(item) - elif item["type"] == "edge": - self.context.edges.append(item) - finally: - self.context.unlock() - self.context.call_count = (self.context.call_count or 0) + len(chunks) - return CStatus() - - def extract_property_graph_by_llm(self, schema, chunk): - prompt = generate_extract_property_graph_prompt(chunk, schema) - if self.example_prompt is not None: - prompt = self.example_prompt + prompt - return self.llm.generate(prompt=prompt) - - def _extract_and_filter_label(self, schema, text) -> List[Dict[str, 
Any]]: - # Use regex to extract a JSON object with curly braces - json_match = re.search(r"({.*})", text, re.DOTALL) - if not json_match: - log.critical( - "Invalid property graph! No JSON object found, " - "please check the output format example in prompt." - ) - return [] - json_str = json_match.group(1).strip() - - items = [] - try: - property_graph = json.loads(json_str) - # Expect property_graph to be a dict with keys "vertices" and "edges" - if not ( - isinstance(property_graph, dict) - and "vertices" in property_graph - and "edges" in property_graph - ): - log.critical( - "Invalid property graph format; expecting 'vertices' and 'edges'." - ) - return items - - # Create sets for valid vertex and edge labels based on the schema - vertex_label_set = {vertex["name"] for vertex in schema["vertexlabels"]} - edge_label_set = {edge["name"] for edge in schema["edgelabels"]} - - def process_items(item_list, valid_labels, item_type): - for item in item_list: - if not isinstance(item, dict): - log.warning( - "Invalid property graph item type '%s'.", type(item) - ) - continue - if not self.NECESSARY_ITEM_KEYS.issubset(item.keys()): - log.warning("Invalid item keys '%s'.", item.keys()) - continue - if item["label"] not in valid_labels: - log.warning( - "Invalid %s label '%s' has been ignored.", - item_type, - item["label"], - ) - continue - items.append(item) - - process_items(property_graph["vertices"], vertex_label_set, "vertex") - process_items(property_graph["edges"], edge_label_set, "edge") - except json.JSONDecodeError: - log.critical( - "Invalid property graph JSON! 
Please check the extracted JSON data carefully" - ) - return items diff --git a/hugegraph-llm/src/hugegraph_llm/state/ai_state.py b/hugegraph-llm/src/hugegraph_llm/state/ai_state.py index 0543aa2b4..6d3418c00 100644 --- a/hugegraph-llm/src/hugegraph_llm/state/ai_state.py +++ b/hugegraph-llm/src/hugegraph_llm/state/ai_state.py @@ -25,6 +25,14 @@ class WkFlowInput(GParam): example_prompt: str = None # need by graph information extract schema: str = None # Schema information requeired by SchemaNode graph_name: str = None + data_json = None + extract_type = None + query_examples = None + few_shot_schema = None + # Fields related to PromptGenerate + source_text: str = None # Original text + scenario: str = None # Scenario description + example_name: str = None # Example name def reset(self, _: CStatus) -> None: self.texts = None @@ -33,6 +41,14 @@ def reset(self, _: CStatus) -> None: self.example_prompt = None self.schema = None self.graph_name = None + self.data_json = None + self.extract_type = None + self.query_examples = None + self.few_shot_schema = None + # PromptGenerate related configuration + self.source_text = None + self.scenario = None + self.example_name = None class WkFlowState(GParam): @@ -49,6 +65,8 @@ class WkFlowState(GParam): graph_result = None keywords_embeddings = None + generated_extract_prompt: Optional[str] = None + def setup(self): self.schema = None self.simple_schema = None @@ -63,6 +81,8 @@ def setup(self): self.graph_result = None self.keywords_embeddings = None + self.generated_extract_prompt = None + return CStatus() def to_json(self): @@ -79,3 +99,11 @@ def to_json(self): for k, v in self.__dict__.items() if not k.startswith("_") and v is not None } + + # Implement a method that assigns keys from data_json as WkFlowState member variables + def assign_from_json(self, data_json: dict): + """ + Assigns each key in the input json object as a member variable of WkFlowState. 
+ """ + for k, v in data_json.items(): + setattr(self, k, v) diff --git a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py index f61b5f843..ccace69f2 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py @@ -36,6 +36,15 @@ def get_graph_index_info(): + try: + scheduler = SchedulerSingleton.get_instance() + return scheduler.schedule_flow("get_graph_index_info") + except Exception as e: # pylint: disable=broad-exception-caught + log.error(e) + raise gr.Error(str(e)) + + +def get_graph_index_info_old(): builder = KgBuilder( LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client() ) @@ -150,6 +159,15 @@ def extract_graph(input_file, input_text, schema, example_prompt) -> str: def update_vid_embedding(): + scheduler = SchedulerSingleton.get_instance() + try: + return scheduler.schedule_flow("update_vid_embeddings") + except Exception as e: # pylint: disable=broad-exception-caught + log.error(e) + raise gr.Error(str(e)) + + +def update_vid_embedding_old(): builder = KgBuilder( LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client() ) @@ -166,6 +184,18 @@ def update_vid_embedding(): def import_graph_data(data: str, schema: str) -> Union[str, Dict[str, Any]]: + try: + scheduler = SchedulerSingleton.get_instance() + return scheduler.schedule_flow("import_graph_data", data, schema) + except Exception as e: # pylint: disable=W0718 + log.error(e) + traceback.print_exc() + # Note: can't use gr.Error here + gr.Warning(str(e) + " Please check the graph data format/type carefully.") + return data + + +def import_graph_data_old(data: str, schema: str) -> Union[str, Dict[str, Any]]: try: data_json = json.loads(data.strip()) log.debug("Import graph data: %s", data) @@ -190,6 +220,16 @@ def import_graph_data(data: str, schema: str) -> Union[str, Dict[str, Any]]: def build_schema(input_text, query_example, 
few_shot): + scheduler = SchedulerSingleton.get_instance() + try: + return scheduler.schedule_flow( + "build_schema", input_text, query_example, few_shot + ) + except (TypeError, ValueError) as e: + raise gr.Error(f"Schema generation failed: {e}") + + +def build_schema_old(input_text, query_example, few_shot): context = { "raw_texts": [input_text] if input_text else [], "query_examples": [], From 0c9a30577bd190c8fc53fc2614a6d91cfc1e470c Mon Sep 17 00:00:00 2001 From: LingXiao Qi Date: Tue, 30 Sep 2025 00:24:18 +0800 Subject: [PATCH 3/5] refactor: text2germlin with PCGraph framework (#50) Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Linyu <94553312+weijinglin@users.noreply.github.com> --- .../src/hugegraph_llm/api/admin_api.py | 8 +- .../api/exceptions/rag_exceptions.py | 4 +- .../hugegraph_llm/api/models/rag_requests.py | 106 +++++++---- .../src/hugegraph_llm/api/rag_api.py | 43 +++-- .../src/hugegraph_llm/config/admin_config.py | 2 + .../src/hugegraph_llm/config/generate.py | 4 +- .../hugegraph_llm/config/hugegraph_config.py | 1 + .../src/hugegraph_llm/config/llm_config.py | 21 +- .../config/models/base_config.py | 22 ++- .../config/models/base_prompt_config.py | 24 ++- .../src/hugegraph_llm/config/prompt_config.py | 1 + .../demo/rag_demo/admin_block.py | 39 ++-- .../src/hugegraph_llm/demo/rag_demo/app.py | 4 +- .../demo/rag_demo/configs_block.py | 91 +++------ .../demo/rag_demo/other_block.py | 13 +- .../hugegraph_llm/demo/rag_demo/rag_block.py | 74 +++++--- .../demo/rag_demo/text2gremlin_block.py | 109 ++++++++--- .../demo/rag_demo/vector_graph_block.py | 74 +++----- .../src/hugegraph_llm/document/chunk_split.py | 14 +- .../flows/get_graph_index_info.py | 4 +- .../src/hugegraph_llm/flows/graph_extract.py | 4 +- .../hugegraph_llm/flows/import_graph_data.py | 8 +- .../hugegraph_llm/flows/prompt_generate.py | 4 +- .../src/hugegraph_llm/flows/scheduler.py | 9 +- .../src/hugegraph_llm/flows/text2gremlin.py | 112 +++++++++++ 
.../src/hugegraph_llm/indices/graph_index.py | 17 +- .../src/hugegraph_llm/indices/vector_index.py | 33 +++- .../hugegraph_llm/middleware/middleware.py | 5 +- .../hugegraph_llm/models/embeddings/base.py | 37 ++-- .../hugegraph_llm/models/embeddings/openai.py | 25 ++- .../src/hugegraph_llm/models/llms/base.py | 34 ++-- .../src/hugegraph_llm/models/llms/init_llm.py | 6 +- .../src/hugegraph_llm/models/llms/litellm.py | 10 +- .../src/hugegraph_llm/models/llms/ollama.py | 30 ++- .../src/hugegraph_llm/models/llms/openai.py | 4 +- .../hugegraph_llm/models/rerankers/cohere.py | 9 +- .../models/rerankers/init_reranker.py | 4 +- .../models/rerankers/siliconflow.py | 9 +- .../nodes/hugegraph_node/gremlin_execute.py | 68 +++++++ .../nodes/hugegraph_node/schema.py | 2 +- .../index_node/gremlin_example_index_query.py | 49 +++++ .../nodes/llm_node/schema_build.py | 8 +- .../nodes/llm_node/text2gremlin.py | 70 +++++++ .../operators/common_op/check_schema.py | 40 +--- .../operators/common_op/merge_dedup_rerank.py | 15 +- .../operators/document_op/word_extract.py | 3 +- .../hugegraph_llm/operators/graph_rag_task.py | 4 +- .../hugegraph_op/commit_to_hugegraph.py | 58 ++---- .../operators/hugegraph_op/graph_rag_query.py | 63 ++++-- .../operators/hugegraph_op/schema_manager.py | 4 +- .../index_op/build_gremlin_example_index.py | 14 +- .../index_op/build_semantic_index.py | 30 +-- .../operators/index_op/build_vector_index.py | 4 +- .../index_op/gremlin_example_index_query.py | 29 ++- .../operators/index_op/semantic_id_query.py | 33 ++-- .../operators/index_op/vector_index_query.py | 8 +- .../operators/kg_construction_task.py | 11 +- .../operators/llm_op/answer_synthesize.py | 179 ++++++++++++------ .../operators/llm_op/disambiguate_data.py | 3 +- .../operators/llm_op/gremlin_generate.py | 19 +- .../operators/llm_op/info_extract.py | 8 +- .../operators/llm_op/keyword_extract.py | 24 +++ .../operators/llm_op/prompt_generate.py | 6 +- .../llm_op/property_graph_extract.py | 18 +- 
.../operators/llm_op/schema_build.py | 14 +- .../src/hugegraph_llm/state/ai_state.py | 30 ++- .../src/hugegraph_llm/utils/anchor.py | 9 +- .../src/hugegraph_llm/utils/decorators.py | 1 + .../hugegraph_llm/utils/embedding_utils.py | 9 +- .../hugegraph_llm/utils/graph_index_utils.py | 40 +--- .../hugegraph_llm/utils/hugegraph_utils.py | 26 ++- hugegraph-llm/src/hugegraph_llm/utils/log.py | 2 +- .../hugegraph_llm/utils/vector_index_utils.py | 16 +- hugegraph-llm/src/tests/config/test_config.py | 1 + .../embeddings/test_openai_embedding.py | 1 + .../tests/models/llms/test_ollama_client.py | 7 +- .../operators/common_op/test_check_schema.py | 9 +- .../operators/common_op/test_nltk_helper.py | 1 + .../src/pyhugegraph/api/auth.py | 16 +- .../src/pyhugegraph/api/graph.py | 12 +- .../src/pyhugegraph/api/schema.py | 12 +- .../api/schema_manage/index_label.py | 12 +- .../src/pyhugegraph/api/services.py | 8 +- .../src/pyhugegraph/api/traverser.py | 45 ++--- .../src/pyhugegraph/client.py | 2 +- .../pyhugegraph/example/hugegraph_example.py | 14 +- .../structure/property_key_data.py | 4 +- .../src/pyhugegraph/utils/huge_config.py | 10 +- .../src/pyhugegraph/utils/huge_router.py | 8 +- .../src/pyhugegraph/utils/log.py | 4 +- .../src/pyhugegraph/utils/util.py | 23 ++- .../src/tests/api/test_auth.py | 8 +- .../src/tests/api/test_version.py | 8 +- .../src/tests/client_utils.py | 6 +- 94 files changed, 1330 insertions(+), 816 deletions(-) create mode 100644 hugegraph-llm/src/hugegraph_llm/flows/text2gremlin.py create mode 100644 hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/gremlin_execute.py create mode 100644 hugegraph-llm/src/hugegraph_llm/nodes/index_node/gremlin_example_index_query.py create mode 100644 hugegraph-llm/src/hugegraph_llm/nodes/llm_node/text2gremlin.py diff --git a/hugegraph-llm/src/hugegraph_llm/api/admin_api.py b/hugegraph-llm/src/hugegraph_llm/api/admin_api.py index 05648d48e..4c192c29c 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/admin_api.py +++ 
b/hugegraph-llm/src/hugegraph_llm/api/admin_api.py @@ -31,8 +31,12 @@ def admin_http_api(router: APIRouter, log_stream): @router.post("/logs", status_code=status.HTTP_200_OK) async def log_stream_api(req: LogStreamRequest): if admin_settings.admin_token != req.admin_token: - raise generate_response(RAGResponse(status_code=status.HTTP_403_FORBIDDEN, #pylint: disable=E0702 - message="Invalid admin_token")) + raise generate_response( + RAGResponse( + status_code=status.HTTP_403_FORBIDDEN, # pylint: disable=E0702 + message="Invalid admin_token", + ) + ) log_path = os.path.join("logs", req.log_file) # Create a StreamingResponse that reads from the log stream generator diff --git a/hugegraph-llm/src/hugegraph_llm/api/exceptions/rag_exceptions.py b/hugegraph-llm/src/hugegraph_llm/api/exceptions/rag_exceptions.py index 75eb14cf3..18723e30b 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/exceptions/rag_exceptions.py +++ b/hugegraph-llm/src/hugegraph_llm/api/exceptions/rag_exceptions.py @@ -21,7 +21,9 @@ class ExternalException(HTTPException): def __init__(self): - super().__init__(status_code=400, detail="Connect failed with error code -1, please check the input.") + super().__init__( + status_code=400, detail="Connect failed with error code -1, please check the input." 
+ ) class ConnectionFailedException(HTTPException): diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py index cf227e8bd..f46aea02c 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py +++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py @@ -24,10 +24,10 @@ class GraphConfigRequest(BaseModel): - url: str = Query('127.0.0.1:8080', description="hugegraph client url.") - graph: str = Query('hugegraph', description="hugegraph client name.") - user: str = Query('', description="hugegraph client user.") - pwd: str = Query('', description="hugegraph client pwd.") + url: str = Query("127.0.0.1:8080", description="hugegraph client url.") + graph: str = Query("hugegraph", description="hugegraph client name.") + user: str = Query("", description="hugegraph client user.") + pwd: str = Query("", description="hugegraph client pwd.") gs: str = None @@ -36,22 +36,42 @@ class RAGRequest(BaseModel): raw_answer: bool = Query(False, description="Use LLM to generate answer directly") vector_only: bool = Query(False, description="Use LLM to generate answer with vector") graph_only: bool = Query(True, description="Use LLM to generate answer with graph RAG only") - graph_vector_answer: bool = Query(False, description="Use LLM to generate answer with vector & GraphRAG") + graph_vector_answer: bool = Query( + False, description="Use LLM to generate answer with vector & GraphRAG" + ) graph_ratio: float = Query(0.5, description="The ratio of GraphRAG ans & vector ans") - rerank_method: Literal["bleu", "reranker"] = Query("bleu", description="Method to rerank the results.") - near_neighbor_first: bool = Query(False, description="Prioritize near neighbors in the search results.") - custom_priority_info: str = Query("", description="Custom information to prioritize certain results.") + rerank_method: Literal["bleu", "reranker"] = Query( + "bleu", description="Method to rerank the 
results." + ) + near_neighbor_first: bool = Query( + False, description="Prioritize near neighbors in the search results." + ) + custom_priority_info: str = Query( + "", description="Custom information to prioritize certain results." + ) # Graph Configs - max_graph_items: int = Query(30, description="Maximum number of items for GQL queries in graph.") + max_graph_items: int = Query( + 30, description="Maximum number of items for GQL queries in graph." + ) topk_return_results: int = Query(20, description="Number of sorted results to return finally.") - vector_dis_threshold: float = Query(0.9, description="Threshold for vector similarity\ - (results greater than this will be ignored).") - topk_per_keyword: int = Query(1, description="TopK results returned for each keyword \ - extracted from the query, by default only the most similar one is returned.") - client_config: Optional[GraphConfigRequest] = Query(None, description="hugegraph server config.") + vector_dis_threshold: float = Query( + 0.9, + description="Threshold for vector similarity\ + (results greater than this will be ignored).", + ) + topk_per_keyword: int = Query( + 1, + description="TopK results returned for each keyword \ + extracted from the query, by default only the most similar one is returned.", + ) + client_config: Optional[GraphConfigRequest] = Query( + None, description="hugegraph server config." + ) # Keep prompt params in the end - answer_prompt: Optional[str] = Query(prompt.answer_prompt, description="Prompt to guide the answer generation.") + answer_prompt: Optional[str] = Query( + prompt.answer_prompt, description="Prompt to guide the answer generation." 
+ ) keywords_extract_prompt: Optional[str] = Query( prompt.keywords_extract_prompt, description="Prompt for extracting keywords from query.", @@ -67,22 +87,39 @@ class RAGRequest(BaseModel): class GraphRAGRequest(BaseModel): query: str = Query(..., description="Query you want to ask") # Graph Configs - max_graph_items: int = Query(30, description="Maximum number of items for GQL queries in graph.") + max_graph_items: int = Query( + 30, description="Maximum number of items for GQL queries in graph." + ) topk_return_results: int = Query(20, description="Number of sorted results to return finally.") - vector_dis_threshold: float = Query(0.9, description="Threshold for vector similarity \ - (results greater than this will be ignored).") - topk_per_keyword: int = Query(1, description="TopK results returned for each keyword extracted\ - from the query, by default only the most similar one is returned.") + vector_dis_threshold: float = Query( + 0.9, + description="Threshold for vector similarity \ + (results greater than this will be ignored).", + ) + topk_per_keyword: int = Query( + 1, + description="TopK results returned for each keyword extracted\ + from the query, by default only the most similar one is returned.", + ) - client_config: Optional[GraphConfigRequest] = Query(None, description="hugegraph server config.") + client_config: Optional[GraphConfigRequest] = Query( + None, description="hugegraph server config." + ) get_vertex_only: bool = Query(False, description="return only keywords & vertex (early stop).") gremlin_tmpl_num: int = Query( - 1, description="Number of Gremlin templates to use. If num <=0 means template is not provided" + 1, + description="Number of Gremlin templates to use. If num <=0 means template is not provided", + ) + rerank_method: Literal["bleu", "reranker"] = Query( + "bleu", description="Method to rerank the results." + ) + near_neighbor_first: bool = Query( + False, description="Prioritize near neighbors in the search results." 
+ ) + custom_priority_info: str = Query( + "", description="Custom information to prioritize certain results." ) - rerank_method: Literal["bleu", "reranker"] = Query("bleu", description="Method to rerank the results.") - near_neighbor_first: bool = Query(False, description="Prioritize near neighbors in the search results.") - custom_priority_info: str = Query("", description="Custom information to prioritize certain results.") gremlin_prompt: Optional[str] = Query( prompt.gremlin_generate_prompt, description="Prompt for the Text2Gremlin query.", @@ -115,6 +152,7 @@ class LogStreamRequest(BaseModel): admin_token: Optional[str] = None log_file: Optional[str] = "llm-server.log" + class GremlinOutputType(str, Enum): MATCH_RESULT = "match_result" TEMPLATE_GREMLIN = "template_gremlin" @@ -122,32 +160,36 @@ class GremlinOutputType(str, Enum): TEMPLATE_EXECUTION_RESULT = "template_execution_result" RAW_EXECUTION_RESULT = "raw_execution_result" + class GremlinGenerateRequest(BaseModel): query: str example_num: Optional[int] = Query( - 0, - description="Number of Gremlin templates to use.(0 means no templates)" + 0, description="Number of Gremlin templates to use.(0 means no templates)" ) gremlin_prompt: Optional[str] = Query( prompt.gremlin_generate_prompt, description="Prompt for the Text2Gremlin query.", ) - client_config: Optional[GraphConfigRequest] = Query(None, description="hugegraph server config.") + client_config: Optional[GraphConfigRequest] = Query( + None, description="hugegraph server config." + ) output_types: Optional[List[GremlinOutputType]] = Query( default=[GremlinOutputType.TEMPLATE_GREMLIN], description=""" a list can contain "match_result","template_gremlin", "raw_gremlin","template_execution_result","raw_execution_result" You can specify which type of result do you need. Empty means all types. 
- """ + """, ) - @field_validator('gremlin_prompt') + @field_validator("gremlin_prompt") @classmethod def validate_prompt_placeholders(cls, v): if v is not None: - required_placeholders = ['{query}', '{schema}', '{example}', '{vertices}'] + required_placeholders = ["{query}", "{schema}", "{example}", "{vertices}"] missing = [p for p in required_placeholders if p not in v] if missing: - raise ValueError(f"Prompt template is missing required placeholders: {', '.join(missing)}") + raise ValueError( + f"Prompt template is missing required placeholders: {', '.join(missing)}" + ) return v diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py index bfa76e7ef..356176e4e 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py +++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py @@ -32,6 +32,8 @@ from hugegraph_llm.config import huge_settings from hugegraph_llm.config import llm_settings, prompt from hugegraph_llm.utils.log import log +from hugegraph_llm.flows.scheduler import SchedulerSingleton + # pylint: disable=too-many-statements @@ -74,7 +76,9 @@ def rag_answer_api(req: RAGRequest): "query": req.query, **{ key: value - for key, value in zip(["raw_answer", "vector_only", "graph_only", "graph_vector_answer"], result) + for key, value in zip( + ["raw_answer", "vector_only", "graph_only", "graph_vector_answer"], result + ) if getattr(req, key) }, } @@ -103,11 +107,12 @@ def graph_rag_recall_api(req: GraphRAGRequest): near_neighbor_first=req.near_neighbor_first, custom_related_information=req.custom_priority_info, gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt, - get_vertex_only=req.get_vertex_only + get_vertex_only=req.get_vertex_only, ) if req.get_vertex_only: from hugegraph_llm.operators.hugegraph_op.graph_rag_query import GraphRAGQuery + graph_rag = GraphRAGQuery() graph_rag.init_client(result) vertex_details = graph_rag.get_vertex_details(result["match_vids"]) @@ -135,7 +140,8 @@ def 
graph_rag_recall_api(req: GraphRAGRequest): except Exception as e: log.error("Unexpected error occurred: %s", e) raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred." + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="An unexpected error occurred.", ) from e @router.post("/config/graph", status_code=status.HTTP_201_CREATED) @@ -150,7 +156,9 @@ def llm_config_api(req: LLMConfigRequest): llm_settings.llm_type = req.llm_type if req.llm_type == "openai": - res = apply_llm_conf(req.api_key, req.api_base, req.language_model, req.max_tokens, origin_call="http") + res = apply_llm_conf( + req.api_key, req.api_base, req.language_model, req.max_tokens, origin_call="http" + ) else: res = apply_llm_conf(req.host, req.port, req.language_model, None, origin_call="http") return generate_response(RAGResponse(status_code=res, message="Missing Value")) @@ -160,7 +168,9 @@ def embedding_config_api(req: LLMConfigRequest): llm_settings.embedding_type = req.llm_type if req.llm_type == "openai": - res = apply_embedding_conf(req.api_key, req.api_base, req.language_model, origin_call="http") + res = apply_embedding_conf( + req.api_key, req.api_base, req.language_model, origin_call="http" + ) else: res = apply_embedding_conf(req.host, req.port, req.language_model, origin_call="http") return generate_response(RAGResponse(status_code=res, message="Missing Value")) @@ -170,7 +180,9 @@ def rerank_config_api(req: RerankerConfigRequest): llm_settings.reranker_type = req.reranker_type if req.reranker_type == "cohere": - res = apply_reranker_conf(req.api_key, req.reranker_model, req.cohere_base_url, origin_call="http") + res = apply_reranker_conf( + req.api_key, req.reranker_model, req.cohere_base_url, origin_call="http" + ) elif req.reranker_type == "siliconflow": res = apply_reranker_conf(req.api_key, req.reranker_model, None, origin_call="http") else: @@ -182,16 +194,23 @@ def text2gremlin_api(req: GremlinGenerateRequest): 
try: set_graph_config(req) + # Basic parameter validation: empty query => 400 + if not req.query or not str(req.query).strip(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Query must not be empty." + ) + output_types_str_list = None if req.output_types: output_types_str_list = [ot.value for ot in req.output_types] - response_dict = gremlin_generate_selective_func( - inp=req.query, - example_num=req.example_num, - schema_input=huge_settings.graph_name, - gremlin_prompt_input=req.gremlin_prompt, - requested_outputs=output_types_str_list, + response_dict = SchedulerSingleton.get_instance().schedule_flow( + "text2gremlin", + req.query, + req.example_num, + huge_settings.graph_name, + req.gremlin_prompt, + output_types_str_list, ) return response_dict except HTTPException as e: diff --git a/hugegraph-llm/src/hugegraph_llm/config/admin_config.py b/hugegraph-llm/src/hugegraph_llm/config/admin_config.py index b2814de41..fabc75de4 100644 --- a/hugegraph-llm/src/hugegraph_llm/config/admin_config.py +++ b/hugegraph-llm/src/hugegraph_llm/config/admin_config.py @@ -18,8 +18,10 @@ from typing import Optional from .models import BaseConfig + class AdminConfig(BaseConfig): """Admin settings""" + enable_login: Optional[str] = "False" user_token: Optional[str] = "4321" admin_token: Optional[str] = "xxxx" diff --git a/hugegraph-llm/src/hugegraph_llm/config/generate.py b/hugegraph-llm/src/hugegraph_llm/config/generate.py index 36910e480..4b40e899f 100644 --- a/hugegraph-llm/src/hugegraph_llm/config/generate.py +++ b/hugegraph-llm/src/hugegraph_llm/config/generate.py @@ -22,7 +22,9 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="Generate hugegraph-llm config file") - parser.add_argument("-U", "--update", default=True, action="store_true", help="Update the config file") + parser.add_argument( + "-U", "--update", default=True, action="store_true", help="Update the config file" + ) args = parser.parse_args() if args.update: 
huge_settings.generate_env() diff --git a/hugegraph-llm/src/hugegraph_llm/config/hugegraph_config.py b/hugegraph-llm/src/hugegraph_llm/config/hugegraph_config.py index e51008d96..69abf0fbc 100644 --- a/hugegraph-llm/src/hugegraph_llm/config/hugegraph_config.py +++ b/hugegraph-llm/src/hugegraph_llm/config/hugegraph_config.py @@ -21,6 +21,7 @@ class HugeGraphConfig(BaseConfig): """HugeGraph settings""" + # graph server config graph_url: Optional[str] = "127.0.0.1:8080" graph_name: Optional[str] = "hugegraph" diff --git a/hugegraph-llm/src/hugegraph_llm/config/llm_config.py b/hugegraph-llm/src/hugegraph_llm/config/llm_config.py index 64d851f5a..eb094ef88 100644 --- a/hugegraph-llm/src/hugegraph_llm/config/llm_config.py +++ b/hugegraph-llm/src/hugegraph_llm/config/llm_config.py @@ -24,6 +24,7 @@ class LLMConfig(BaseConfig): """LLM settings""" + language: Literal["EN", "CN"] = "EN" chat_llm_type: Literal["openai", "litellm", "ollama/local"] = "openai" extract_llm_type: Literal["openai", "litellm", "ollama/local"] = "openai" @@ -35,23 +36,33 @@ class LLMConfig(BaseConfig): hybrid_llm_weights: Optional[float] = 0.5 # TODO: divide RAG part if necessary # 1. 
OpenAI settings - openai_chat_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1") + openai_chat_api_base: Optional[str] = os.environ.get( + "OPENAI_BASE_URL", "https://api.openai.com/v1" + ) openai_chat_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") openai_chat_language_model: Optional[str] = "gpt-4.1-mini" - openai_extract_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1") + openai_extract_api_base: Optional[str] = os.environ.get( + "OPENAI_BASE_URL", "https://api.openai.com/v1" + ) openai_extract_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") openai_extract_language_model: Optional[str] = "gpt-4.1-mini" - openai_text2gql_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1") + openai_text2gql_api_base: Optional[str] = os.environ.get( + "OPENAI_BASE_URL", "https://api.openai.com/v1" + ) openai_text2gql_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") openai_text2gql_language_model: Optional[str] = "gpt-4.1-mini" - openai_embedding_api_base: Optional[str] = os.environ.get("OPENAI_EMBEDDING_BASE_URL", "https://api.openai.com/v1") + openai_embedding_api_base: Optional[str] = os.environ.get( + "OPENAI_EMBEDDING_BASE_URL", "https://api.openai.com/v1" + ) openai_embedding_api_key: Optional[str] = os.environ.get("OPENAI_EMBEDDING_API_KEY") openai_embedding_model: Optional[str] = "text-embedding-3-small" openai_chat_tokens: int = 8192 openai_extract_tokens: int = 256 openai_text2gql_tokens: int = 4096 # 2. Rerank settings - cohere_base_url: Optional[str] = os.environ.get("CO_API_URL", "https://api.cohere.com/v1/rerank") + cohere_base_url: Optional[str] = os.environ.get( + "CO_API_URL", "https://api.cohere.com/v1/rerank" + ) reranker_api_key: Optional[str] = None reranker_model: Optional[str] = None # 3. 
Ollama settings diff --git a/hugegraph-llm/src/hugegraph_llm/config/models/base_config.py b/hugegraph-llm/src/hugegraph_llm/config/models/base_config.py index dfe9d1056..5fec3a778 100644 --- a/hugegraph-llm/src/hugegraph_llm/config/models/base_config.py +++ b/hugegraph-llm/src/hugegraph_llm/config/models/base_config.py @@ -31,12 +31,15 @@ class BaseConfig(BaseSettings): class Config: env_file = env_path case_sensitive = False - extra = 'ignore' # ignore extra fields to avoid ValidationError + extra = "ignore" # ignore extra fields to avoid ValidationError env_ignore_empty = True def generate_env(self): if os.path.exists(env_path): - log.info("%s already exists, do you want to override with the default configuration? (y/n)", env_path) + log.info( + "%s already exists, do you want to override with the default configuration? (y/n)", + env_path, + ) update = input() if update.lower() != "y": return @@ -96,8 +99,12 @@ def _sync_env_to_object(self, env_config, config_dict): obj_value_str = str(obj_value) if obj_value is not None else "" if env_value != obj_value_str: - log.info("Update configuration from the file: %s=%s (Original value: %s)", - env_key, env_value, obj_value_str) + log.info( + "Update configuration from the file: %s=%s (Original value: %s)", + env_key, + env_value, + obj_value_str, + ) # Update the object attribute (using lowercase key) setattr(self, env_key.lower(), env_value) @@ -106,8 +113,11 @@ def _sync_object_to_env(self, env_config, config_dict): for obj_key, obj_value in config_dict.items(): if obj_key not in env_config: obj_value_str = str(obj_value) if obj_value is not None else "" - log.info("Add configuration items to the environment variable file: %s=%s", - obj_key, obj_value) + log.info( + "Add configuration items to the environment variable file: %s=%s", + obj_key, + obj_value, + ) # Add to .env set_key(env_path, obj_key, obj_value_str, quote_mode="never") diff --git a/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py 
b/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py index 1008c3c13..4b0c4dc76 100644 --- a/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py +++ b/hugegraph-llm/src/hugegraph_llm/config/models/base_prompt_config.py @@ -32,11 +32,14 @@ class LiteralStr(str): pass + def literal_str_representer(dumper, data): - return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + yaml.add_representer(LiteralStr, literal_str_representer) + class BasePromptConfig: graph_schema: str = "" extract_graph_prompt: str = "" @@ -54,9 +57,7 @@ def ensure_yaml_file_exists(self): current_dir = Path.cwd().resolve() project_root = get_project_root() if current_dir == project_root: - log.info( - "Current working directory is the project root, proceeding to run the app." - ) + log.info("Current working directory is the project root, proceeding to run the app.") else: error_msg = ( f"Current working directory is not the project root. " @@ -74,16 +75,20 @@ def ensure_yaml_file_exists(self): setattr(self, key, value) # Check if the language in the .env file matches the language in the YAML file - env_lang = (self.llm_settings.language.lower() - if hasattr(self, 'llm_settings') and self.llm_settings.language - else 'en') - yaml_lang = data.get('_language_generated', 'en').lower() + env_lang = ( + self.llm_settings.language.lower() + if hasattr(self, "llm_settings") and self.llm_settings.language + else "en" + ) + yaml_lang = data.get("_language_generated", "en").lower() if env_lang.strip() != yaml_lang.strip(): log.warning( "Prompt was changed '.env' language is '%s', " "but '%s' was generated for '%s'. 
" "Regenerating the prompt file...", - env_lang, F_NAME, yaml_lang + env_lang, + F_NAME, + yaml_lang, ) if self.llm_settings.language.lower() == "cn": self.answer_prompt = self.answer_prompt_CN @@ -105,6 +110,7 @@ def save_to_yaml(self): def to_literal(val): return LiteralStr(val) if isinstance(val, str) else val + data = { "graph_schema": to_literal(self.graph_schema), "text2gql_graph_schema": to_literal(self.text2gql_graph_schema), diff --git a/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py b/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py index eaccbefa2..cc79b3cef 100644 --- a/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py +++ b/hugegraph-llm/src/hugegraph_llm/config/prompt_config.py @@ -23,6 +23,7 @@ class PromptConfig(BasePromptConfig): def __init__(self, llm_config_object): self.llm_settings = llm_config_object + # Data is detached from llm_op/answer_synthesize.py answer_prompt_EN: str = """You are an expert in the fields of knowledge graphs and natural language processing. diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/admin_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/admin_block.py index 2d5937a43..1b2032b23 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/admin_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/admin_block.py @@ -30,7 +30,7 @@ async def log_stream(log_path: str, lines: int = 125): Stream the content of a log file like `tail -f`. 
""" try: - with open(log_path, 'r', encoding='utf-8') as file: + with open(log_path, "r", encoding="utf-8") as file: buffer = deque(file, maxlen=lines) for line in buffer: yield line # Yield the initial lines @@ -50,8 +50,8 @@ async def log_stream(log_path: str, lines: int = 125): def read_llm_server_log(lines=250): log_path = "logs/llm-server.log" try: - with open(log_path, "r", encoding='utf-8', errors="replace") as f: - return ''.join(deque(f, maxlen=lines)) + with open(log_path, "r", encoding="utf-8", errors="replace") as f: + return "".join(deque(f, maxlen=lines)) except FileNotFoundError: log.critical("Log file not found: %s", log_path) return "LLM Server log file not found." @@ -61,10 +61,10 @@ def read_llm_server_log(lines=250): def clear_llm_server_log(): log_path = "logs/llm-server.log" try: - with open(log_path, "w", encoding='utf-8') as f: + with open(log_path, "w", encoding="utf-8") as f: f.truncate(0) # Clear the contents of the file return "LLM Server log cleared." - except Exception as e: #pylint: disable=W0718 + except Exception as e: # pylint: disable=W0718 log.error("An error occurred while clearing the log: %s", str(e)) return "Failed to clear LLM Server log." @@ -84,7 +84,7 @@ def check_password(password, request: Request = None): gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), - gr.update(visible=False) + gr.update(visible=False), ) # Log the failed attempt with IP address log.error("Incorrect password attempt from IP: %s", client_ip) @@ -93,7 +93,7 @@ def check_password(password, request: Request = None): gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), - gr.update(value="Incorrect password. Access denied.", visible=True) + gr.update(value="Incorrect password. 
Access denied.", visible=True), ) @@ -110,10 +110,7 @@ def create_admin_block(): # Error message box, initially hidden error_message = gr.Textbox( - label="", - visible=False, - interactive=False, - elem_classes="error-message" + label="", visible=False, interactive=False, elem_classes="error-message" ) # Button to submit password @@ -136,26 +133,32 @@ def create_admin_block(): clear_llm_server_button = gr.Button("Clear LLM Server Log", visible=False) with gr.Column(): # Button to refresh LLM Server log manually - refresh_llm_server_button = gr.Button("Refresh LLM Server Log", visible=False, - variant="primary") + refresh_llm_server_button = gr.Button( + "Refresh LLM Server Log", visible=False, variant="primary" + ) # Define what happens when the password is submitted - submit_button.click( #pylint: disable=E1101 + submit_button.click( # pylint: disable=E1101 fn=check_password, inputs=[password_input], - outputs=[llm_server_log_output, hidden_row, clear_llm_server_button, - refresh_llm_server_button, error_message], + outputs=[ + llm_server_log_output, + hidden_row, + clear_llm_server_button, + refresh_llm_server_button, + error_message, + ], ) # Define what happens when the Clear LLM Server Log button is clicked - clear_llm_server_button.click( #pylint: disable=E1101 + clear_llm_server_button.click( # pylint: disable=E1101 fn=clear_llm_server_log, inputs=[], outputs=[llm_server_log_output], ) # Define what happens when the Refresh LLM Server Log button is clicked - refresh_llm_server_button.click( #pylint: disable=E1101 + refresh_llm_server_button.click( # pylint: disable=E1101 fn=read_llm_server_log, inputs=[], outputs=[llm_server_log_output], diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py index 4e575dddd..4e3f4de39 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/app.py @@ -166,9 +166,7 @@ def create_app(): # 
settings.check_env() prompt.update_yaml_file() auth_enabled = admin_settings.enable_login.lower() == "true" - log.info( - "(Status) Authentication is %s now.", "enabled" if auth_enabled else "disabled" - ) + log.info("(Status) Authentication is %s now.", "enabled" if auth_enabled else "disabled") api_auth = APIRouter(dependencies=[Depends(authenticate)] if auth_enabled else []) hugegraph_llm = init_rag_ui() diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py index 01ea24aa8..8c595c30d 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py @@ -71,9 +71,7 @@ def test_api_connection( log.debug("Request URL: %s", url) try: if method.upper() == "GET": - resp = requests.get( - url, headers=headers, params=params, timeout=(1.0, 5.0), auth=auth - ) + resp = requests.get(url, headers=headers, params=params, timeout=(1.0, 5.0), auth=auth) elif method.upper() == "POST": resp = requests.post( url, @@ -125,9 +123,7 @@ def apply_embedding_config(arg1, arg2, arg3, origin_call=None) -> int: llm_settings.ollama_embedding_host = arg1 llm_settings.ollama_embedding_port = int(arg2) llm_settings.ollama_embedding_model = arg3 - status_code = test_api_connection( - f"http://{arg1}:{arg2}", origin_call=origin_call - ) + status_code = test_api_connection(f"http://{arg1}:{arg2}", origin_call=origin_call) elif embedding_option == "litellm": llm_settings.litellm_embedding_api_key = arg1 llm_settings.litellm_embedding_api_base = arg2 @@ -218,8 +214,7 @@ def apply_llm_config( setattr(llm_settings, f"openai_{current_llm_config}_tokens", int(max_tokens)) test_url = ( - getattr(llm_settings, f"openai_{current_llm_config}_api_base") - + "/chat/completions" + getattr(llm_settings, f"openai_{current_llm_config}_api_base") + "/chat/completions" ) data = { "model": model_name, @@ -233,9 +228,7 @@ def apply_llm_config( elif 
llm_option == "ollama/local": setattr(llm_settings, f"ollama_{current_llm_config}_host", api_key_or_host) - setattr( - llm_settings, f"ollama_{current_llm_config}_port", int(api_base_or_port) - ) + setattr(llm_settings, f"ollama_{current_llm_config}_port", int(api_base_or_port)) setattr(llm_settings, f"ollama_{current_llm_config}_language_model", model_name) status_code = test_api_connection( f"http://{api_key_or_host}:{api_base_or_port}", origin_call=origin_call @@ -243,12 +236,8 @@ def apply_llm_config( elif llm_option == "litellm": setattr(llm_settings, f"litellm_{current_llm_config}_api_key", api_key_or_host) - setattr( - llm_settings, f"litellm_{current_llm_config}_api_base", api_base_or_port - ) - setattr( - llm_settings, f"litellm_{current_llm_config}_language_model", model_name - ) + setattr(llm_settings, f"litellm_{current_llm_config}_api_base", api_base_or_port) + setattr(llm_settings, f"litellm_{current_llm_config}_language_model", model_name) setattr(llm_settings, f"litellm_{current_llm_config}_tokens", int(max_tokens)) status_code = test_litellm_chat( @@ -295,7 +284,9 @@ def create_configs_block() -> list: ), ] graph_config_button = gr.Button("Apply Configuration") - graph_config_button.click(apply_graph_config, inputs=graph_config_input) # pylint: disable=no-member + graph_config_button.click( + apply_graph_config, inputs=graph_config_input + ) # pylint: disable=no-member # TODO : use OOP to refactor the following code with gr.Accordion("2. 
Set up the LLM.", open=False): @@ -373,13 +364,9 @@ def chat_llm_settings(llm_type): ), ] else: - llm_config_input = [ - gr.Textbox(value="", visible=False) for _ in range(4) - ] + llm_config_input = [gr.Textbox(value="", visible=False) for _ in range(4)] llm_config_button = gr.Button("Apply configuration") - llm_config_button.click( - apply_llm_config_with_chat_op, inputs=llm_config_input - ) + llm_config_button.click(apply_llm_config_with_chat_op, inputs=llm_config_input) # Determine whether there are Settings in the.env file env_path = os.path.join( os.getcwd(), ".env" @@ -419,9 +406,7 @@ def extract_llm_settings(llm_type): label="api_base", ), gr.Textbox( - value=getattr( - llm_settings, "openai_extract_language_model" - ), + value=getattr(llm_settings, "openai_extract_language_model"), label="model_name", ), gr.Textbox( @@ -440,9 +425,7 @@ def extract_llm_settings(llm_type): label="port", ), gr.Textbox( - value=getattr( - llm_settings, "ollama_extract_language_model" - ), + value=getattr(llm_settings, "ollama_extract_language_model"), label="model_name", ), gr.Textbox(value="", visible=False), @@ -460,9 +443,7 @@ def extract_llm_settings(llm_type): info="If you want to use the default api_base, please keep it blank", ), gr.Textbox( - value=getattr( - llm_settings, "litellm_extract_language_model" - ), + value=getattr(llm_settings, "litellm_extract_language_model"), label="model_name", info="Please refer to https://docs.litellm.ai/docs/providers", ), @@ -472,13 +453,9 @@ def extract_llm_settings(llm_type): ), ] else: - llm_config_input = [ - gr.Textbox(value="", visible=False) for _ in range(4) - ] + llm_config_input = [gr.Textbox(value="", visible=False) for _ in range(4)] llm_config_button = gr.Button("Apply configuration") - llm_config_button.click( - apply_llm_config_with_extract_op, inputs=llm_config_input - ) + llm_config_button.click(apply_llm_config_with_extract_op, inputs=llm_config_input) with gr.Tab(label="text2gql"): text2gql_llm_dropdown = 
gr.Dropdown( @@ -503,9 +480,7 @@ def text2gql_llm_settings(llm_type): label="api_base", ), gr.Textbox( - value=getattr( - llm_settings, "openai_text2gql_language_model" - ), + value=getattr(llm_settings, "openai_text2gql_language_model"), label="model_name", ), gr.Textbox( @@ -524,9 +499,7 @@ def text2gql_llm_settings(llm_type): label="port", ), gr.Textbox( - value=getattr( - llm_settings, "ollama_text2gql_language_model" - ), + value=getattr(llm_settings, "ollama_text2gql_language_model"), label="model_name", ), gr.Textbox(value="", visible=False), @@ -544,9 +517,7 @@ def text2gql_llm_settings(llm_type): info="If you want to use the default api_base, please keep it blank", ), gr.Textbox( - value=getattr( - llm_settings, "litellm_text2gql_language_model" - ), + value=getattr(llm_settings, "litellm_text2gql_language_model"), label="model_name", info="Please refer to https://docs.litellm.ai/docs/providers", ), @@ -556,13 +527,9 @@ def text2gql_llm_settings(llm_type): ), ] else: - llm_config_input = [ - gr.Textbox(value="", visible=False) for _ in range(4) - ] + llm_config_input = [gr.Textbox(value="", visible=False) for _ in range(4)] llm_config_button = gr.Button("Apply configuration") - llm_config_button.click( - apply_llm_config_with_text2gql_op, inputs=llm_config_input - ) + llm_config_button.click(apply_llm_config_with_text2gql_op, inputs=llm_config_input) with gr.Accordion("3. 
Set up the Embedding.", open=False): embedding_dropdown = gr.Dropdown( @@ -594,12 +561,8 @@ def embedding_settings(embedding_type): elif embedding_type == "ollama/local": with gr.Row(): embedding_config_input = [ - gr.Textbox( - value=llm_settings.ollama_embedding_host, label="host" - ), - gr.Textbox( - value=str(llm_settings.ollama_embedding_port), label="port" - ), + gr.Textbox(value=llm_settings.ollama_embedding_host, label="host"), + gr.Textbox(value=str(llm_settings.ollama_embedding_port), label="port"), gr.Textbox( value=llm_settings.ollama_embedding_model, label="model_name", @@ -648,9 +611,7 @@ def embedding_settings(embedding_type): @gr.render(inputs=[reranker_dropdown]) def reranker_settings(reranker_type): - llm_settings.reranker_type = ( - reranker_type if reranker_type != "None" else None - ) + llm_settings.reranker_type = reranker_type if reranker_type != "None" else None if reranker_type == "cohere": with gr.Row(): reranker_config_input = [ @@ -660,9 +621,7 @@ def reranker_settings(reranker_type): type="password", ), gr.Textbox(value=llm_settings.reranker_model, label="model"), - gr.Textbox( - value=llm_settings.cohere_base_url, label="base_url" - ), + gr.Textbox(value=llm_settings.cohere_base_url, label="base_url"), ] elif reranker_type == "siliconflow": with gr.Row(): diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/other_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/other_block.py index da10f50f4..8b78328f3 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/other_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/other_block.py @@ -31,7 +31,9 @@ def create_other_block(): gr.Markdown("""## Other Tools """) with gr.Row(): - inp = gr.Textbox(value="g.V().limit(10)", label="Gremlin query", show_copy_button=True, lines=8) + inp = gr.Textbox( + value="g.V().limit(10)", label="Gremlin query", show_copy_button=True, lines=8 + ) out = gr.Code(label="Output", language="json", elem_classes="code-container-show") btn = 
gr.Button("Run Gremlin query") btn.click(fn=run_gremlin_query, inputs=[inp], outputs=out) # pylint: disable=no-member @@ -39,7 +41,9 @@ def create_other_block(): gr.Markdown("---") with gr.Row(): inp = [] - out = gr.Textbox(label="Backup Graph Manually (Auto backup at 1:00 AM everyday)", show_copy_button=True) + out = gr.Textbox( + label="Backup Graph Manually (Auto backup at 1:00 AM everyday)", show_copy_button=True + ) btn = gr.Button("Backup Graph Data") btn.click(fn=backup_data, inputs=inp, outputs=out) # pylint: disable=no-member with gr.Accordion("Init HugeGraph test data (🚧)", open=False): @@ -55,10 +59,7 @@ async def lifespan(app: FastAPI): # pylint: disable=W0621 log.info("Starting background scheduler...") scheduler = AsyncIOScheduler() scheduler.add_job( - backup_data, - trigger=CronTrigger(hour=1, minute=0), - id="daily_backup", - replace_existing=True + backup_data, trigger=CronTrigger(hour=1, minute=0), id="daily_backup", replace_existing=True ) scheduler.start() diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py index 982436b0f..8f70c34bd 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py @@ -91,7 +91,9 @@ def rag_answer( near_neighbor_first=near_neighbor_first, topk_return_results=topk_return_results, ) - rag.synthesize_answer(raw_answer, vector_only_answer, graph_only_answer, graph_vector_answer, answer_prompt) + rag.synthesize_answer( + raw_answer, vector_only_answer, graph_only_answer, graph_vector_answer, answer_prompt + ) try: context = rag.run( @@ -146,6 +148,7 @@ def update_ui_configs( graph_search = graph_only_answer or graph_vector_answer return graph_search, gremlin_prompt, vector_search + async def rag_answer_streaming( text: str, raw_answer: bool, @@ -188,9 +191,9 @@ async def rag_answer_streaming( if vector_search: rag.query_vector_index() if graph_search: - 
rag.extract_keywords(extract_template=keywords_extract_prompt).keywords_to_vid().import_schema( - huge_settings.graph_name - ).query_graphdb( + rag.extract_keywords( + extract_template=keywords_extract_prompt + ).keywords_to_vid().import_schema(huge_settings.graph_name).query_graphdb( num_gremlin_generate_example=gremlin_tmpl_num, gremlin_prompt=gremlin_prompt, ) @@ -202,7 +205,9 @@ async def rag_answer_streaming( # rag.synthesize_answer(raw_answer, vector_only_answer, graph_only_answer, graph_vector_answer, answer_prompt) try: - context = rag.run(verbose=True, query=text, vector_search=vector_search, graph_search=graph_search) + context = rag.run( + verbose=True, query=text, vector_search=vector_search, graph_search=graph_search + ) if context.get("switch_to_bleu"): gr.Warning("Online reranker fails, automatically switches to local bleu rerank.") answer_synthesize = AnswerSynthesize( @@ -228,6 +233,7 @@ async def rag_answer_streaming( log.critical(e) raise gr.Error(f"An unexpected error occurred: {str(e)}") + @with_task_id def create_rag_block(): # pylint: disable=R0915 (too-many-statements),C0301 @@ -235,7 +241,9 @@ def create_rag_block(): with gr.Row(): with gr.Column(scale=2): # with gr.Blocks().queue(max_size=20, default_concurrency_limit=5): - inp = gr.Textbox(value=prompt.default_question, label="Question", show_copy_button=True, lines=3) + inp = gr.Textbox( + value=prompt.default_question, label="Question", show_copy_button=True, lines=3 + ) # TODO: Only support inline formula now. 
Should support block formula gr.Markdown("Basic LLM Answer", elem_classes="output-box-label") @@ -275,10 +283,16 @@ def create_rag_block(): with gr.Column(scale=1): with gr.Row(): raw_radio = gr.Radio(choices=[True, False], value=False, label="Basic LLM Answer") - vector_only_radio = gr.Radio(choices=[True, False], value=False, label="Vector-only Answer") + vector_only_radio = gr.Radio( + choices=[True, False], value=False, label="Vector-only Answer" + ) with gr.Row(): - graph_only_radio = gr.Radio(choices=[True, False], value=True, label="Graph-only Answer") - graph_vector_radio = gr.Radio(choices=[True, False], value=False, label="Graph-Vector Answer") + graph_only_radio = gr.Radio( + choices=[True, False], value=True, label="Graph-only Answer" + ) + graph_vector_radio = gr.Radio( + choices=[True, False], value=False, label="Graph-Vector Answer" + ) def toggle_slider(enable): return gr.update(interactive=enable) @@ -291,8 +305,12 @@ def toggle_slider(enable): value="reranker" if online_rerank else "bleu", label="Rerank method", ) - example_num = gr.Number(value=-1, label="Template Num (<0 means disable text2gql) ", precision=0) - graph_ratio = gr.Slider(0, 1, 0.6, label="Graph Ratio", step=0.1, interactive=False) + example_num = gr.Number( + value=-1, label="Template Num (<0 means disable text2gql) ", precision=0 + ) + graph_ratio = gr.Slider( + 0, 1, 0.6, label="Graph Ratio", step=0.1, interactive=False + ) graph_vector_radio.change( toggle_slider, inputs=graph_vector_radio, outputs=graph_ratio @@ -325,8 +343,8 @@ def toggle_slider(enable): example_num, ], outputs=[raw_out, vector_only_out, graph_only_out, graph_vector_out], - queue=True, # Enable queueing for this event - concurrency_limit=5, # Maximum of 5 concurrent executions + queue=True, # Enable queueing for this event + concurrency_limit=5, # Maximum of 5 concurrent executions ) gr.Markdown( @@ -394,18 +412,20 @@ def several_rag_answer( total_rows = len(df) for index, row in df.iterrows(): question = 
row.iloc[0] - basic_llm_answer, vector_only_answer, graph_only_answer, graph_vector_answer = rag_answer( - question, - is_raw_answer, - is_vector_only_answer, - is_graph_only_answer, - is_graph_vector_answer, - graph_ratio_ui, - rerank_method_ui, - near_neighbor_first_ui, - custom_related_information_ui, - answer_prompt, - keywords_extract_prompt, + basic_llm_answer, vector_only_answer, graph_only_answer, graph_vector_answer = ( + rag_answer( + question, + is_raw_answer, + is_vector_only_answer, + is_graph_only_answer, + is_graph_vector_answer, + graph_ratio_ui, + rerank_method_ui, + near_neighbor_first_ui, + custom_related_information_ui, + answer_prompt, + keywords_extract_prompt, + ) ) df.at[index, "Basic LLM Answer"] = basic_llm_answer df.at[index, "Vector-only Answer"] = vector_only_answer @@ -418,7 +438,9 @@ def several_rag_answer( with gr.Row(): with gr.Column(): - questions_file = gr.File(file_types=[".xlsx", ".csv"], label="Questions File (.xlsx & csv)") + questions_file = gr.File( + file_types=[".xlsx", ".csv"], label="Questions File (.xlsx & csv)" + ) with gr.Column(): test_template_file = os.path.join(resource_path, "demo", "questions_template.xlsx") gr.File(value=test_template_file, label="Download Template File") diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py index 7d682403f..6600d7c41 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py @@ -33,11 +33,13 @@ from hugegraph_llm.utils.embedding_utils import get_index_folder_name from hugegraph_llm.utils.hugegraph_utils import run_gremlin_query from hugegraph_llm.utils.log import log +from hugegraph_llm.flows.scheduler import SchedulerSingleton @dataclass class GremlinResult: """Standardized result class for gremlin_generate function""" + success: bool match_result: str template_gremlin: Optional[str] = 
None @@ -47,13 +49,19 @@ class GremlinResult: error_message: Optional[str] = None @classmethod - def error(cls, message: str) -> 'GremlinResult': + def error(cls, message: str) -> "GremlinResult": """Create an error result""" return cls(success=False, match_result=message, error_message=message) @classmethod - def success_result(cls, match_result: str, template_gremlin: str, - raw_gremlin: str, template_exec: str, raw_exec: str) -> 'GremlinResult': + def success_result( + cls, + match_result: str, + template_gremlin: str, + raw_gremlin: str, + template_exec: str, + raw_exec: str, + ) -> "GremlinResult": """Create a successful result""" return cls( success=True, @@ -61,7 +69,7 @@ def success_result(cls, match_result: str, template_gremlin: str, template_gremlin=template_gremlin, raw_gremlin=raw_gremlin, template_exec_result=template_exec, - raw_exec_result=raw_exec + raw_exec_result=raw_exec, ) @@ -93,6 +101,7 @@ def build_example_vector_index(temp_file) -> dict: target_file = os.path.join(resource_path, folder_name, "gremlin_examples", file_name) try: import shutil + shutil.copy2(full_path, target_file) log.info("Successfully copied file to: %s", target_file) except (OSError, IOError) as e: @@ -143,7 +152,7 @@ def _configure_output_types(requested_outputs): "template_gremlin": True, "raw_gremlin": True, "template_execution_result": True, - "raw_execution_result": True + "raw_execution_result": True, } if requested_outputs: for key in output_types: @@ -176,7 +185,9 @@ def _execute_queries(context, output_types): def gremlin_generate( inp, example_num, schema, gremlin_prompt, requested_outputs: Optional[List[str]] = None ) -> GremlinResult: - generator = GremlinGenerator(llm=LLMs().get_text2gql_llm(), embedding=Embeddings().get_embedding()) + generator = GremlinGenerator( + llm=LLMs().get_text2gql_llm(), embedding=Embeddings().get_embedding() + ) sm = SchemaManager(graph_name=schema) processed_schema, short_schema = _process_schema(schema, generator, sm) @@ -196,7 
+207,9 @@ def gremlin_generate( _execute_queries(context, output_types) - match_result = json.dumps(context.get("match_result", "No Results"), ensure_ascii=False, indent=2) + match_result = json.dumps( + context.get("match_result", "No Results"), ensure_ascii=False, indent=2 + ) return GremlinResult.success_result( match_result=match_result, template_gremlin=context["result"], @@ -220,7 +233,11 @@ def simple_schema(schema: Dict[str, Any]) -> Dict[str, Any]: if "edgelabels" in schema: mini_schema["edgelabels"] = [] for edge in schema["edgelabels"]: - new_edge = {key: edge[key] for key in ["name", "source_label", "target_label", "properties"] if key in edge} + new_edge = { + key: edge[key] + for key in ["name", "source_label", "target_label", "properties"] + if key in edge + } mini_schema["edgelabels"].append(new_edge) return mini_schema @@ -228,17 +245,40 @@ def simple_schema(schema: Dict[str, Any]) -> Dict[str, Any]: def gremlin_generate_for_ui(inp, example_num, schema, gremlin_prompt): """UI wrapper for gremlin_generate that returns tuple for Gradio compatibility""" - result = gremlin_generate(inp, example_num, schema, gremlin_prompt) - - if not result.success: - return result.match_result, "", "", "", "" + # Execute via scheduler + try: + res = SchedulerSingleton.get_instance().schedule_flow( + "text2gremlin", + inp, + int(example_num) if isinstance(example_num, (int, float, str)) else 2, + schema, + gremlin_prompt, + [ + "match_result", + "template_gremlin", + "raw_gremlin", + "template_execution_result", + "raw_execution_result", + ], + ) + except Exception as e: # pylint: disable=broad-except + log.error("UI text2gremlin error: %s", e) + return json.dumps({"error": str(e)}, ensure_ascii=False), "", "", "", "" + + # Backward-compatible mapping for outputs + match_result = res.get("match_result", []) + match_result_str = ( + json.dumps(match_result, ensure_ascii=False, indent=2) + if isinstance(match_result, (list, dict)) + else str(match_result) + ) return ( - 
result.match_result, - result.template_gremlin or "", - result.raw_gremlin or "", - result.template_exec_result or "", - result.raw_exec_result or "" + match_result_str, + res.get("template_gremlin", "") or "", + res.get("raw_gremlin", "") or "", + res.get("template_execution_result", "") or "", + res.get("raw_execution_result", "") or "", ) @@ -253,7 +293,8 @@ def create_text2gremlin_block() -> Tuple: ) with gr.Row(): file = gr.File( - value=os.path.join(resource_path, "demo", "text2gremlin.csv"), label="Upload Text-Gremlin Pairs File" + value=os.path.join(resource_path, "demo", "text2gremlin.csv"), + label="Upload Text-Gremlin Pairs File", ) out = gr.Textbox(label="Result Message") with gr.Row(): @@ -263,22 +304,39 @@ def create_text2gremlin_block() -> Tuple: with gr.Row(): with gr.Column(scale=1): - input_box = gr.Textbox(value=prompt.default_question, label="Nature Language Query", show_copy_button=True) - match = gr.Code(label="Similar Template (TopN)", language="javascript", elem_classes="code-container-show") + input_box = gr.Textbox( + value=prompt.default_question, label="Nature Language Query", show_copy_button=True + ) + match = gr.Code( + label="Similar Template (TopN)", + language="javascript", + elem_classes="code-container-show", + ) initialized_out = gr.Textbox(label="Gremlin With Template", show_copy_button=True) raw_out = gr.Textbox(label="Gremlin Without Template", show_copy_button=True) tmpl_exec_out = gr.Code( - label="Query With Template Output", language="json", elem_classes="code-container-show" + label="Query With Template Output", + language="json", + elem_classes="code-container-show", ) raw_exec_out = gr.Code( - label="Query Without Template Output", language="json", elem_classes="code-container-show" + label="Query Without Template Output", + language="json", + elem_classes="code-container-show", ) with gr.Column(scale=1): - example_num_slider = gr.Slider(minimum=0, maximum=10, step=1, value=2, label="Number of refer examples") - 
schema_box = gr.Textbox(value=prompt.text2gql_graph_schema, label="Schema", lines=2, show_copy_button=True) + example_num_slider = gr.Slider( + minimum=0, maximum=10, step=1, value=2, label="Number of refer examples" + ) + schema_box = gr.Textbox( + value=prompt.text2gql_graph_schema, label="Schema", lines=2, show_copy_button=True + ) prompt_box = gr.Textbox( - value=prompt.gremlin_generate_prompt, label="Prompt", lines=20, show_copy_button=True + value=prompt.gremlin_generate_prompt, + label="Prompt", + lines=20, + show_copy_button=True, ) btn = gr.Button("Text2Gremlin", variant="primary") btn.click( # pylint: disable=no-member @@ -324,6 +382,7 @@ def graph_rag_recall( context = rag.run(verbose=True, query=query, graph_search=True) return context + def gremlin_generate_selective( inp: str, example_num: int, diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py index 4aa476942..56b5de4b3 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py @@ -63,16 +63,12 @@ def generate_prompt_for_ui(source_text, scenario, example_name): Handles the UI logic for generating a new prompt using the new workflow architecture. """ if not all([source_text, scenario, example_name]): - gr.Warning( - "Please provide original text, expected scenario, and select an example!" 
- ) + gr.Warning("Please provide original text, expected scenario, and select an example!") return gr.update() try: # using new architecture scheduler = SchedulerSingleton.get_instance() - result = scheduler.schedule_flow( - "prompt_generate", source_text, scenario, example_name - ) + result = scheduler.schedule_flow("prompt_generate", source_text, scenario, example_name) gr.Info("Prompt generated successfully!") return result except Exception as e: @@ -83,9 +79,7 @@ def generate_prompt_for_ui(source_text, scenario, example_name): def load_example_names(): """Load all candidate examples""" try: - examples_path = os.path.join( - resource_path, "prompt_examples", "prompt_examples.json" - ) + examples_path = os.path.join(resource_path, "prompt_examples", "prompt_examples.json") with open(examples_path, "r", encoding="utf-8") as f: examples = json.load(f) return [example.get("name", "Unnamed example") for example in examples] @@ -99,27 +93,23 @@ def load_query_examples(): language = getattr( prompt, "language", - getattr(prompt.llm_settings, "language", "EN") - if hasattr(prompt, "llm_settings") - else "EN", + ( + getattr(prompt.llm_settings, "language", "EN") + if hasattr(prompt, "llm_settings") + else "EN" + ), ) if language.upper() == "CN": - examples_path = os.path.join( - resource_path, "prompt_examples", "query_examples_CN.json" - ) + examples_path = os.path.join(resource_path, "prompt_examples", "query_examples_CN.json") else: - examples_path = os.path.join( - resource_path, "prompt_examples", "query_examples.json" - ) + examples_path = os.path.join(resource_path, "prompt_examples", "query_examples.json") with open(examples_path, "r", encoding="utf-8") as f: examples = json.load(f) return json.dumps(examples, indent=2, ensure_ascii=False) except (FileNotFoundError, json.JSONDecodeError): try: - examples_path = os.path.join( - resource_path, "prompt_examples", "query_examples.json" - ) + examples_path = os.path.join(resource_path, "prompt_examples", 
"query_examples.json") with open(examples_path, "r", encoding="utf-8") as f: examples = json.load(f) return json.dumps(examples, indent=2, ensure_ascii=False) @@ -130,9 +120,7 @@ def load_query_examples(): def load_schema_fewshot_examples(): """Load few-shot examples from a JSON file""" try: - examples_path = os.path.join( - resource_path, "prompt_examples", "schema_examples.json" - ) + examples_path = os.path.join(resource_path, "prompt_examples", "schema_examples.json") with open(examples_path, "r", encoding="utf-8") as f: examples = json.load(f) return json.dumps(examples, indent=2, ensure_ascii=False) @@ -143,14 +131,10 @@ def load_schema_fewshot_examples(): def update_example_preview(example_name): """Update the display content based on the selected example name.""" try: - examples_path = os.path.join( - resource_path, "prompt_examples", "prompt_examples.json" - ) + examples_path = os.path.join(resource_path, "prompt_examples", "prompt_examples.json") with open(examples_path, "r", encoding="utf-8") as f: all_examples = json.load(f) - selected_example = next( - (ex for ex in all_examples if ex.get("name") == example_name), None - ) + selected_example = next((ex for ex in all_examples if ex.get("name") == example_name), None) if selected_example: return ( @@ -178,9 +162,11 @@ def _create_prompt_helper_block(demo, input_text, info_extract_template): few_shot_dropdown = gr.Dropdown( choices=example_names, label="Select a Few-shot example as a reference", - value=example_names[0] - if example_names and example_names[0] != "No available examples" - else None, + value=( + example_names[0] + if example_names and example_names[0] != "No available examples" + else None + ), ) with gr.Accordion("View example details", open=False): example_desc_preview = gr.Markdown(label="Example description") @@ -193,9 +179,7 @@ def _create_prompt_helper_block(demo, input_text, info_extract_template): interactive=False, ) - generate_prompt_btn = gr.Button( - "🚀 Auto-generate Graph 
Extract Prompt", variant="primary" - ) + generate_prompt_btn = gr.Button("🚀 Auto-generate Graph Extract Prompt", variant="primary") # Bind the change event of the dropdown menu few_shot_dropdown.change( fn=update_example_preview, @@ -287,9 +271,7 @@ def create_vector_graph_block(): lines=15, max_lines=29, ) - out = gr.Code( - label="Output Info", language="json", elem_classes="code-container-edit" - ) + out = gr.Code(label="Output Info", language="json", elem_classes="code-container-edit") with gr.Row(): with gr.Accordion("Get RAG Info", open=False): @@ -298,12 +280,8 @@ def create_vector_graph_block(): graph_index_btn0 = gr.Button("Get Graph Index Info", size="sm") with gr.Accordion("Clear RAG Data", open=False): with gr.Column(): - vector_index_btn1 = gr.Button( - "Clear Chunks Vector Index", size="sm" - ) - graph_index_btn1 = gr.Button( - "Clear Graph Vid Vector Index", size="sm" - ) + vector_index_btn1 = gr.Button("Clear Chunks Vector Index", size="sm") + graph_index_btn1 = gr.Button("Clear Graph Vid Vector Index", size="sm") graph_data_btn0 = gr.Button("Clear Graph Data", size="sm") vector_import_bt = gr.Button("Import into Vector", variant="primary") @@ -376,9 +354,9 @@ def create_vector_graph_block(): inputs=[input_text, input_schema, info_extract_template], ) - graph_loading_bt.click( - import_graph_data, inputs=[out, input_schema], outputs=[out] - ).then(update_vid_embedding).then( + graph_loading_bt.click(import_graph_data, inputs=[out, input_schema], outputs=[out]).then( + update_vid_embedding + ).then( store_prompt, inputs=[input_text, input_schema, info_extract_template], ) diff --git a/hugegraph-llm/src/hugegraph_llm/document/chunk_split.py b/hugegraph-llm/src/hugegraph_llm/document/chunk_split.py index 495ef667c..ee173b284 100644 --- a/hugegraph-llm/src/hugegraph_llm/document/chunk_split.py +++ b/hugegraph-llm/src/hugegraph_llm/document/chunk_split.py @@ -22,9 +22,9 @@ class ChunkSplitter: def __init__( - self, - split_type: Literal["paragraph", 
"sentence"] = "paragraph", - language: Literal["zh", "en"] = "zh" + self, + split_type: Literal["paragraph", "sentence"] = "paragraph", + language: Literal["zh", "en"] = "zh", ): if language == "zh": separators = ["\n\n", "\n", "。", ",", ""] @@ -34,15 +34,11 @@ def __init__( raise ValueError("Argument `language` must be zh or en!") if split_type == "paragraph": self.text_splitter = RecursiveCharacterTextSplitter( - chunk_size=500, - chunk_overlap=30, - separators=separators + chunk_size=500, chunk_overlap=30, separators=separators ) elif split_type == "sentence": self.text_splitter = RecursiveCharacterTextSplitter( - chunk_size=50, - chunk_overlap=0, - separators=separators + chunk_size=50, chunk_overlap=0, separators=separators ) else: raise ValueError("Arg `type` must be paragraph, sentence!") diff --git a/hugegraph-llm/src/hugegraph_llm/flows/get_graph_index_info.py b/hugegraph-llm/src/hugegraph_llm/flows/get_graph_index_info.py index fa10d0199..7d2735352 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/get_graph_index_info.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/get_graph_index_info.py @@ -48,9 +48,7 @@ def build_flow(self, *args, **kwargs): def post_deal(self, pipeline=None): graph_summary_info = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() - folder_name = get_index_folder_name( - huge_settings.graph_name, huge_settings.graph_space - ) + folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) index_dir = str(os.path.join(resource_path, folder_name, "graph_vids")) filename_prefix = get_filename_prefix( llm_settings.embedding_type, diff --git a/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py b/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py index 1b0c98253..55f53b7ad 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py @@ -27,9 +27,7 @@ class GraphExtractFlow(BaseFlow): def __init__(self): pass - def prepare( - self, 
prepared_input: WkFlowInput, schema, texts, example_prompt, extract_type - ): + def prepare(self, prepared_input: WkFlowInput, schema, texts, example_prompt, extract_type): # prepare input data prepared_input.texts = texts prepared_input.language = "zh" diff --git a/hugegraph-llm/src/hugegraph_llm/flows/import_graph_data.py b/hugegraph-llm/src/hugegraph_llm/flows/import_graph_data.py index 5581ef107..0b29b4e64 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/import_graph_data.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/import_graph_data.py @@ -35,9 +35,11 @@ def prepare(self, prepared_input: WkFlowInput, data, schema): raise ValueError(f"Invalid JSON for 'data': {e.msg}") from e log.debug( "Import graph data (truncated): %s", - (data[:512] + "...") - if isinstance(data, str) and len(data) > 512 - else (data if isinstance(data, str) else ""), + ( + (data[:512] + "...") + if isinstance(data, str) and len(data) > 512 + else (data if isinstance(data, str) else "") + ), ) prepared_input.data_json = data_json prepared_input.schema = schema diff --git a/hugegraph-llm/src/hugegraph_llm/flows/prompt_generate.py b/hugegraph-llm/src/hugegraph_llm/flows/prompt_generate.py index aece6bd61..b4a7bf329 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/prompt_generate.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/prompt_generate.py @@ -58,6 +58,4 @@ def post_deal(self, pipeline=None): Process the execution result of PromptGenerate workflow """ res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() - return res.get( - "generated_extract_prompt", "Generation failed. Please check the logs." - ) + return res.get("generated_extract_prompt", "Generation failed. 
Please check the logs.") diff --git a/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py b/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py index 559540ce3..3aedbe7f2 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py @@ -25,6 +25,7 @@ from hugegraph_llm.flows.build_schema import BuildSchemaFlow from hugegraph_llm.flows.prompt_generate import PromptGenerateFlow from hugegraph_llm.utils.log import log +from hugegraph_llm.flows.text2gremlin import Text2GremlinFlow class Scheduler: @@ -62,6 +63,10 @@ def __init__(self, max_pipeline: int = 10): "manager": GPipelineManager(), "flow": PromptGenerateFlow(), } + self.pipeline_pool["text2gremlin"] = { + "manager": GPipelineManager(), + "flow": Text2GremlinFlow(), + } self.max_pipeline = max_pipeline # TODO: Implement Agentic Workflow @@ -96,7 +101,9 @@ def schedule_flow(self, flow: str, *args, **kwargs): flow.prepare(prepared_input, *args, **kwargs) status = pipeline.run() if status.isErr(): - raise RuntimeError(f"Error in flow execution {status.getInfo()}") + error_msg = f"Error in flow execution {status.getInfo()}" + log.error(error_msg) + raise RuntimeError(error_msg) res = flow.post_deal(pipeline) manager.release(pipeline) return res diff --git a/hugegraph-llm/src/hugegraph_llm/flows/text2gremlin.py b/hugegraph-llm/src/hugegraph_llm/flows/text2gremlin.py new file mode 100644 index 000000000..e9ba4276c --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/flows/text2gremlin.py @@ -0,0 +1,112 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from PyCGraph import GPipeline + +from hugegraph_llm.flows.common import BaseFlow +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from hugegraph_llm.nodes.hugegraph_node.schema import SchemaNode +from hugegraph_llm.nodes.index_node.gremlin_example_index_query import GremlinExampleIndexQueryNode +from hugegraph_llm.nodes.llm_node.text2gremlin import Text2GremlinNode +from hugegraph_llm.nodes.hugegraph_node.gremlin_execute import GremlinExecuteNode + +from typing import Any, Dict, List, Optional + + +class Text2GremlinFlow(BaseFlow): + def __init__(self): + pass + + def prepare( + self, + prepared_input: WkFlowInput, + query: str, + example_num: int, + schema_input: str, + gremlin_prompt_input: Optional[str], + requested_outputs: Optional[List[str]], + ): + # sanitize example_num to [0,10], fallback to 2 if invalid + if not isinstance(example_num, int): + example_num = 2 + example_num = max(0, min(10, example_num)) + + # filter requested_outputs to allowed set and cap to 5 + allowed = { + "match_result", + "template_gremlin", + "raw_gremlin", + "template_execution_result", + "raw_execution_result", + } + req = requested_outputs or ["template_gremlin"] + req = [x for x in req if x in allowed] + if not req: + req = ["template_gremlin"] + if len(req) > 5: + req = req[:5] + + prepared_input.query = query + prepared_input.example_num = example_num + prepared_input.schema = schema_input + prepared_input.gremlin_prompt = gremlin_prompt_input + prepared_input.requested_outputs = req + return + + def build_flow( + self, + query: str, + example_num: 
int, + schema_input: str, + gremlin_prompt_input: Optional[str] = None, + requested_outputs: Optional[List[str]] = None, + ): + pipeline = GPipeline() + + prepared_input = WkFlowInput() + self.prepare( + prepared_input, + query=query, + example_num=example_num, + schema_input=schema_input, + gremlin_prompt_input=gremlin_prompt_input, + requested_outputs=requested_outputs, + ) + + pipeline.createGParam(prepared_input, "wkflow_input") + pipeline.createGParam(WkFlowState(), "wkflow_state") + + schema_node = SchemaNode() + ieq_node = GremlinExampleIndexQueryNode() + tgn_node = Text2GremlinNode() + exe_node = GremlinExecuteNode() + + pipeline.registerGElement(schema_node, set(), "schema_node") + pipeline.registerGElement(ieq_node, set(), "gremlin_example_index_query") + pipeline.registerGElement(tgn_node, {schema_node, ieq_node}, "text2gremlin") + pipeline.registerGElement(exe_node, {tgn_node}, "gremlin_execute") + + return pipeline + + def post_deal(self, pipeline=None) -> Dict[str, Any]: + state = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + # 始终返回 5 个标准键,避免前端因过滤异常看不到字段 + return { + "match_result": state.get("match_result", []), + "template_gremlin": state.get("result", ""), + "raw_gremlin": state.get("raw_result", ""), + "template_execution_result": state.get("template_exec_res", ""), + "raw_execution_result": state.get("raw_exec_res", ""), + } diff --git a/hugegraph-llm/src/hugegraph_llm/indices/graph_index.py b/hugegraph-llm/src/hugegraph_llm/indices/graph_index.py index e78aa6d58..694ca014d 100644 --- a/hugegraph-llm/src/hugegraph_llm/indices/graph_index.py +++ b/hugegraph-llm/src/hugegraph_llm/indices/graph_index.py @@ -24,15 +24,16 @@ class GraphIndex: def __init__( - self, - graph_url: Optional[str] = huge_settings.graph_url, - graph_name: Optional[str] = huge_settings.graph_name, - graph_user: Optional[str] = huge_settings.graph_user, - graph_pwd: Optional[str] = huge_settings.graph_pwd, - graph_space: Optional[str] = huge_settings.graph_space, + 
self, + graph_url: Optional[str] = huge_settings.graph_url, + graph_name: Optional[str] = huge_settings.graph_name, + graph_user: Optional[str] = huge_settings.graph_user, + graph_pwd: Optional[str] = huge_settings.graph_pwd, + graph_space: Optional[str] = huge_settings.graph_space, ): - self.client = PyHugeClient(url=graph_url, graph=graph_name, user=graph_user, pwd=graph_pwd, - graphspace=graph_space) + self.client = PyHugeClient( + url=graph_url, graph=graph_name, user=graph_user, pwd=graph_pwd, graphspace=graph_space + ) def clear_graph(self): self.client.gremlin().exec("g.V().drop()") diff --git a/hugegraph-llm/src/hugegraph_llm/indices/vector_index.py b/hugegraph-llm/src/hugegraph_llm/indices/vector_index.py index 641ac6d6e..f85483185 100644 --- a/hugegraph-llm/src/hugegraph_llm/indices/vector_index.py +++ b/hugegraph-llm/src/hugegraph_llm/indices/vector_index.py @@ -37,7 +37,9 @@ def __init__(self, embed_dim: int = 1024): self.properties = [] @staticmethod - def from_index_file(dir_path: str, filename_prefix: str = None, record_miss: bool = True) -> "VectorIndex": + def from_index_file( + dir_path: str, filename_prefix: str = None, record_miss: bool = True + ) -> "VectorIndex": """Load index from files, supporting model-specific filenames. This method loads a Faiss index and its corresponding properties from a directory. @@ -47,13 +49,18 @@ def from_index_file(dir_path: str, filename_prefix: str = None, record_miss: boo matches the number of properties. 
""" index_name = f"{filename_prefix}_{INDEX_FILE_NAME}" if filename_prefix else INDEX_FILE_NAME - property_name = f"{filename_prefix}_{PROPERTIES_FILE_NAME}" if filename_prefix else PROPERTIES_FILE_NAME + property_name = ( + f"{filename_prefix}_{PROPERTIES_FILE_NAME}" if filename_prefix else PROPERTIES_FILE_NAME + ) index_file = os.path.join(dir_path, index_name) properties_file = os.path.join(dir_path, property_name) miss_files = [f for f in [index_file, properties_file] if not os.path.exists(f)] if miss_files: if record_miss: - log.warning("Missing vector files: %s. \nNeed create a new one for it.", ", ".join(miss_files)) + log.warning( + "Missing vector files: %s. \nNeed create a new one for it.", + ", ".join(miss_files), + ) return VectorIndex() try: @@ -61,7 +68,9 @@ def from_index_file(dir_path: str, filename_prefix: str = None, record_miss: boo with open(properties_file, "rb") as f: properties = pkl.load(f) except (RuntimeError, pkl.UnpicklingError, OSError) as e: - log.error("Failed to load index files for model '%s': %s", filename_prefix or "default", e) + log.error( + "Failed to load index files for model '%s': %s", filename_prefix or "default", e + ) raise RuntimeError( f"Could not load index files for model '{filename_prefix or 'default'}'. 
" f"Original error ({type(e).__name__}): {e}" @@ -85,7 +94,9 @@ def to_index_file(self, dir_path: str, filename_prefix: str = None): os.makedirs(dir_path) index_name = f"{filename_prefix}_{INDEX_FILE_NAME}" if filename_prefix else INDEX_FILE_NAME - property_name = f"{filename_prefix}_{PROPERTIES_FILE_NAME}" if filename_prefix else PROPERTIES_FILE_NAME + property_name = ( + f"{filename_prefix}_{PROPERTIES_FILE_NAME}" if filename_prefix else PROPERTIES_FILE_NAME + ) index_file = os.path.join(dir_path, index_name) properties_file = os.path.join(dir_path, property_name) faiss.write_index(self.index, index_file) @@ -115,7 +126,9 @@ def remove(self, props: Union[Set[Any], List[Any]]) -> int: self.properties = [p for i, p in enumerate(self.properties) if i not in indices] return remove_num - def search(self, query_vector: List[float], top_k: int, dis_threshold: float = 0.9) -> List[Any]: + def search( + self, query_vector: List[float], top_k: int, dis_threshold: float = 0.9 + ) -> List[Any]: if self.index.ntotal == 0: return [] @@ -129,7 +142,9 @@ def search(self, query_vector: List[float], top_k: int, dis_threshold: float = 0 results.append(deepcopy(self.properties[i])) log.debug("[✓] Add valid distance %s to results.", dist) else: - log.debug("[x] Distance %s >= threshold %s, ignore this result.", dist, dis_threshold) + log.debug( + "[x] Distance %s >= threshold %s, ignore this result.", dist, dis_threshold + ) return results @staticmethod @@ -140,7 +155,9 @@ def clean(dir_path: str, filename_prefix: str = None): If model_name is None, it targets the default files. 
""" index_name = f"{filename_prefix}_{INDEX_FILE_NAME}" if filename_prefix else INDEX_FILE_NAME - property_name = f"{filename_prefix}_{PROPERTIES_FILE_NAME}" if filename_prefix else PROPERTIES_FILE_NAME + property_name = ( + f"{filename_prefix}_{PROPERTIES_FILE_NAME}" if filename_prefix else PROPERTIES_FILE_NAME + ) index_file = os.path.join(dir_path, index_name) properties_file = os.path.join(dir_path, property_name) diff --git a/hugegraph-llm/src/hugegraph_llm/middleware/middleware.py b/hugegraph-llm/src/hugegraph_llm/middleware/middleware.py index 47c70e1a4..c73242012 100644 --- a/hugegraph-llm/src/hugegraph_llm/middleware/middleware.py +++ b/hugegraph-llm/src/hugegraph_llm/middleware/middleware.py @@ -26,6 +26,7 @@ # TODO: we could use middleware(AOP) in the future (dig out the lifecycle of gradio & fastapi) class UseTimeMiddleware(BaseHTTPMiddleware): """Middleware to add process time to response headers""" + def __init__(self, app): super().__init__(app) @@ -33,7 +34,7 @@ async def dispatch(self, request: Request, call_next): # TODO: handle time record for async task pool in gradio start_time = time.perf_counter() response = await call_next(request) - process_time = (time.perf_counter() - start_time) * 1000 # ms + process_time = (time.perf_counter() - start_time) * 1000 # ms unit = "ms" if process_time > 1000: process_time /= 1000 @@ -46,6 +47,6 @@ async def dispatch(self, request: Request, call_next): request.method, request.query_params, request.client.host, - request.url + request.url, ) return response diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py index db9b2f105..698b92837 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py +++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/base.py @@ -32,9 +32,9 @@ class SimilarityMode(str, Enum): def similarity( - embedding1: Union[List[float], np.ndarray], - embedding2: Union[List[float], np.ndarray], - mode: 
SimilarityMode = SimilarityMode.DEFAULT, + embedding1: Union[List[float], np.ndarray], + embedding2: Union[List[float], np.ndarray], + mode: SimilarityMode = SimilarityMode.DEFAULT, ) -> float: """Get embedding similarity.""" if isinstance(embedding1, list): @@ -57,28 +57,22 @@ class BaseEmbedding(ABC): # TODO: replace all the usage by get_texts_embeddings() & remove it in the future @deprecated("Use get_texts_embeddings() instead in the future.") @abstractmethod - def get_text_embedding( - self, - text: str - ) -> List[float]: + def get_text_embedding(self, text: str) -> List[float]: """Comment""" @abstractmethod - def get_texts_embeddings( - self, - texts: List[str] - ) -> List[List[float]]: + def get_texts_embeddings(self, texts: List[str]) -> List[List[float]]: """Get embeddings for multiple texts in a single batch. - + This method should efficiently process multiple texts at once by leveraging the embedding model's batching capabilities, which is typically more efficient than processing texts individually. - + Parameters ---------- texts : List[str] A list of text strings to be embedded. - + Returns ------- List[List[float]] @@ -87,12 +81,9 @@ def get_texts_embeddings( """ @abstractmethod - async def async_get_texts_embeddings( - self, - texts: List[str] - ) -> List[List[float]]: + async def async_get_texts_embeddings(self, texts: List[str]) -> List[List[float]]: """Get embeddings for multiple texts in a single batch asynchronously. - + This method should efficiently process multiple texts at once by leveraging the embedding model's batching capabilities, which is typically more efficient than processing texts individually. @@ -101,7 +92,7 @@ async def async_get_texts_embeddings( ---------- texts : List[str] A list of text strings to be embedded. 
- + Returns ------- List[List[float]] @@ -111,9 +102,9 @@ async def async_get_texts_embeddings( @staticmethod def similarity( - embedding1: Union[List[float], np.ndarray], - embedding2: Union[List[float], np.ndarray], - mode: SimilarityMode = SimilarityMode.DEFAULT, + embedding1: Union[List[float], np.ndarray], + embedding2: Union[List[float], np.ndarray], + mode: SimilarityMode = SimilarityMode.DEFAULT, ) -> float: """Get embedding similarity.""" if isinstance(embedding1, list): diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py index f4026ad7f..d0e15f000 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py +++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/openai.py @@ -23,12 +23,12 @@ class OpenAIEmbedding: def __init__( - self, - model_name: str = "text-embedding-3-small", - api_key: Optional[str] = None, - api_base: Optional[str] = None + self, + model_name: str = "text-embedding-3-small", + api_key: Optional[str] = None, + api_base: Optional[str] = None, ): - api_key = api_key or '' + api_key = api_key or "" self.client = OpenAI(api_key=api_key, base_url=api_base) self.aclient = AsyncOpenAI(api_key=api_key, base_url=api_base) self.model_name = model_name @@ -38,21 +38,18 @@ def get_text_embedding(self, text: str) -> List[float]: response = self.client.embeddings.create(input=text, model=self.model_name) return response.data[0].embedding - def get_texts_embeddings( - self, - texts: List[str] - ) -> List[List[float]]: + def get_texts_embeddings(self, texts: List[str]) -> List[List[float]]: """Get embeddings for multiple texts in a single batch. - + This method efficiently processes multiple texts at once by leveraging OpenAI's batching capabilities, which is more efficient than processing texts individually. - + Parameters ---------- texts : List[str] A list of text strings to be embedded. 
- + Returns ------- List[List[float]] @@ -64,7 +61,7 @@ def get_texts_embeddings( async def async_get_texts_embeddings(self, texts: List[str]) -> List[List[float]]: """Get embeddings for multiple texts in a single batch asynchronously. - + This method should efficiently process multiple texts at once by leveraging the embedding model's batching capabilities, which is typically more efficient than processing texts individually. @@ -73,7 +70,7 @@ async def async_get_texts_embeddings(self, texts: List[str]) -> List[List[float] ---------- texts : List[str] A list of text strings to be embedded. - + Returns ------- List[List[float]] diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/base.py b/hugegraph-llm/src/hugegraph_llm/models/llms/base.py index c6bfa44a8..69c082690 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/base.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/base.py @@ -24,48 +24,48 @@ class BaseLLM(ABC): @abstractmethod def generate( - self, - messages: Optional[List[Dict[str, Any]]] = None, - prompt: Optional[str] = None, + self, + messages: Optional[List[Dict[str, Any]]] = None, + prompt: Optional[str] = None, ) -> str: """Comment""" @abstractmethod async def agenerate( - self, - messages: Optional[List[Dict[str, Any]]] = None, - prompt: Optional[str] = None, + self, + messages: Optional[List[Dict[str, Any]]] = None, + prompt: Optional[str] = None, ) -> str: """Comment""" @abstractmethod def generate_streaming( - self, - messages: Optional[List[Dict[str, Any]]] = None, - prompt: Optional[str] = None, - on_token_callback: Optional[Callable] = None, + self, + messages: Optional[List[Dict[str, Any]]] = None, + prompt: Optional[str] = None, + on_token_callback: Optional[Callable] = None, ) -> Generator[str, None, None]: """Comment""" @abstractmethod async def agenerate_streaming( - self, - messages: Optional[List[Dict[str, Any]]] = None, - prompt: Optional[str] = None, - on_token_callback: Optional[Callable] = None, + self, + messages: 
Optional[List[Dict[str, Any]]] = None, + prompt: Optional[str] = None, + on_token_callback: Optional[Callable] = None, ) -> AsyncGenerator[str, None]: """Comment""" @abstractmethod def num_tokens_from_string( - self, - string: str, + self, + string: str, ) -> str: """Given a string returns the number of tokens the given string consists of""" @abstractmethod def max_allowed_token_length( - self, + self, ) -> int: """Returns the maximum number of tokens the LLM can handle""" diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py b/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py index 7e1eaab68..9121fca09 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py @@ -173,8 +173,4 @@ def get_text2gql_llm(self): if __name__ == "__main__": client = LLMs().get_chat_llm() print(client.generate(prompt="What is the capital of China?")) - print( - client.generate( - messages=[{"role": "user", "content": "What is the capital of China?"}] - ) - ) + print(client.generate(messages=[{"role": "user", "content": "What is the capital of China?"}])) diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py b/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py index b9cc0f19f..6f3c8129c 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/litellm.py @@ -51,7 +51,7 @@ def __init__( @retry( stop=stop_after_attempt(2), wait=wait_exponential(multiplier=1, min=2, max=5), - retry=retry_if_exception_type((RateLimitError, BudgetExceededError, APIError)) + retry=retry_if_exception_type((RateLimitError, BudgetExceededError, APIError)), ) def generate( self, @@ -80,12 +80,12 @@ def generate( @retry( stop=stop_after_attempt(2), wait=wait_exponential(multiplier=1, min=2, max=5), - retry=retry_if_exception_type((RateLimitError, BudgetExceededError, APIError)) + retry=retry_if_exception_type((RateLimitError, BudgetExceededError, 
APIError)), ) async def agenerate( - self, - messages: Optional[List[Dict[str, Any]]] = None, - prompt: Optional[str] = None, + self, + messages: Optional[List[Dict[str, Any]]] = None, + prompt: Optional[str] = None, ) -> str: """Generate a response to the query messages/prompt asynchronously.""" if messages is None: diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py b/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py index 5354ba306..6d08ce8cd 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py @@ -28,6 +28,7 @@ class OllamaClient(BaseLLM): """LLM wrapper should take in a prompt and return a string.""" + def __init__(self, model: str, host: str = "127.0.0.1", port: int = 11434, **kwargs): self.model = model self.client = ollama.Client(host=f"http://{host}:{port}", **kwargs) @@ -49,9 +50,9 @@ def generate( messages=messages, ) usage = { - "prompt_tokens": response['prompt_eval_count'], - "completion_tokens": response['eval_count'], - "total_tokens": response['prompt_eval_count'] + response['eval_count'], + "prompt_tokens": response["prompt_eval_count"], + "completion_tokens": response["eval_count"], + "total_tokens": response["prompt_eval_count"] + response["eval_count"], } log.info("Token usage: %s", json.dumps(usage)) return response["message"]["content"] @@ -61,9 +62,9 @@ def generate( @retry(tries=3, delay=1) async def agenerate( - self, - messages: Optional[List[Dict[str, Any]]] = None, - prompt: Optional[str] = None, + self, + messages: Optional[List[Dict[str, Any]]] = None, + prompt: Optional[str] = None, ) -> str: """Comment""" if messages is None: @@ -75,9 +76,9 @@ async def agenerate( messages=messages, ) usage = { - "prompt_tokens": response['prompt_eval_count'], - "completion_tokens": response['eval_count'], - "total_tokens": response['prompt_eval_count'] + response['eval_count'], + "prompt_tokens": response["prompt_eval_count"], + "completion_tokens": 
response["eval_count"], + "total_tokens": response["prompt_eval_count"] + response["eval_count"], } log.info("Token usage: %s", json.dumps(usage)) return response["message"]["content"] @@ -96,11 +97,7 @@ def generate_streaming( assert prompt is not None, "Messages or prompt must be provided." messages = [{"role": "user", "content": prompt}] - for chunk in self.client.chat( - model=self.model, - messages=messages, - stream=True - ): + for chunk in self.client.chat(model=self.model, messages=messages, stream=True): if not chunk["message"]: log.debug("Received empty chunk['message'] in streaming chunk: %s", chunk) continue @@ -122,9 +119,7 @@ async def agenerate_streaming( try: async_generator = await self.async_client.chat( - model=self.model, - messages=messages, - stream=True + model=self.model, messages=messages, stream=True ) async for chunk in async_generator: token = chunk.get("message", {}).get("content", "") @@ -135,7 +130,6 @@ async def agenerate_streaming( print(f"Retrying LLM call {e}") raise e - def num_tokens_from_string( self, string: str, diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py index 88cea3976..e1088c890 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/openai.py @@ -42,7 +42,7 @@ def __init__( max_tokens: int = 8092, temperature: float = 0.01, ) -> None: - api_key = api_key or '' + api_key = api_key or "" self.client = OpenAI(api_key=api_key, base_url=api_base) self.aclient = AsyncOpenAI(api_key=api_key, base_url=api_base) self.model = model_name @@ -186,7 +186,7 @@ async def agenerate_streaming( temperature=self.temperature, max_tokens=self.max_tokens, messages=messages, - stream=True + stream=True, ) async for chunk in completions: if not chunk.choices: diff --git a/hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py index 
1710acfc2..3bf481ce2 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py +++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/cohere.py @@ -31,16 +31,21 @@ def __init__( self.base_url = base_url self.model = model - def get_rerank_lists(self, query: str, documents: List[str], top_n: Optional[int] = None) -> List[str]: + def get_rerank_lists( + self, query: str, documents: List[str], top_n: Optional[int] = None + ) -> List[str]: if not top_n: top_n = len(documents) - assert top_n <= len(documents), "'top_n' should be less than or equal to the number of documents" + assert top_n <= len( + documents + ), "'top_n' should be less than or equal to the number of documents" if top_n == 0: return [] url = self.base_url from pyhugegraph.utils.constants import Constants + headers = { "accept": Constants.HEADER_CONTENT_TYPE, "content-type": Constants.HEADER_CONTENT_TYPE, diff --git a/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py index aa9f0c061..6136d61b4 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py +++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/init_reranker.py @@ -32,5 +32,7 @@ def get_reranker(self): model=llm_settings.reranker_model, ) if self.reranker_type == "siliconflow": - return SiliconReranker(api_key=llm_settings.reranker_api_key, model=llm_settings.reranker_model) + return SiliconReranker( + api_key=llm_settings.reranker_api_key, model=llm_settings.reranker_model + ) raise Exception("Reranker type is not supported!") diff --git a/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py b/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py index d63b0ba3d..e4a9b550a 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py +++ b/hugegraph-llm/src/hugegraph_llm/models/rerankers/siliconflow.py @@ -29,10 +29,14 @@ def __init__( self.api_key = api_key self.model = model - def 
get_rerank_lists(self, query: str, documents: List[str], top_n: Optional[int] = None) -> List[str]: + def get_rerank_lists( + self, query: str, documents: List[str], top_n: Optional[int] = None + ) -> List[str]: if not top_n: top_n = len(documents) - assert top_n <= len(documents), "'top_n' should be less than or equal to the number of documents" + assert top_n <= len( + documents + ), "'top_n' should be less than or equal to the number of documents" if top_n == 0: return [] @@ -48,6 +52,7 @@ def get_rerank_lists(self, query: str, documents: List[str], top_n: Optional[int "top_n": top_n, } from pyhugegraph.utils.constants import Constants + headers = { "accept": Constants.HEADER_CONTENT_TYPE, "content-type": Constants.HEADER_CONTENT_TYPE, diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/gremlin_execute.py b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/gremlin_execute.py new file mode 100644 index 000000000..98fdcdd1b --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/gremlin_execute.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Any, Dict + +from PyCGraph import CStatus + +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.utils.hugegraph_utils import run_gremlin_query + + +def _ensure_limit(query: str, default_limit: int = 100) -> str: + if not query: + return query + q_lower = query.lower() + if "limit(" in q_lower: + return query + if any(token in q_lower for token in ["g.v(", ".v(", "g.e(", ".e("]): + return f"{query}.limit({default_limit})" + return query + + +class GremlinExecuteNode(BaseNode): + def node_init(self): + return CStatus() + + def operator_schedule(self, data_json: Dict[str, Any]): + # Read requested outputs from wk_input + requested = getattr(self.wk_input, "requested_outputs", None) or [] + need_template = "template_execution_result" in requested + need_raw = "raw_execution_result" in requested + + tmpl_q = data_json.get("result", "") + raw_q = data_json.get("raw_result", "") + + if need_template: + try: + safe_q = _ensure_limit(tmpl_q) + data_json["template_exec_res"] = run_gremlin_query(query=safe_q) + except Exception as exc: # pylint: disable=broad-except + data_json["template_exec_res"] = f"{exc}" + else: + data_json["template_exec_res"] = "" + + if need_raw: + try: + safe_q = _ensure_limit(raw_q) + data_json["raw_exec_res"] = run_gremlin_query(query=safe_q) + except Exception as exc: # pylint: disable=broad-except + data_json["raw_exec_res"] = f"{exc}" + else: + data_json["raw_exec_res"] = "" + + return data_json diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py index 71c490b20..84719d9eb 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py @@ -62,7 +62,7 @@ def node_init(self): return CStatus() def operator_schedule(self, data_json): - print(f"check data json {data_json}") + log.debug("SchemaNode input state: %s", data_json) if 
self.schema.startswith("{"): try: return self.check_schema.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/gremlin_example_index_query.py b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/gremlin_example_index_query.py new file mode 100644 index 000000000..eb033d869 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/gremlin_example_index_query.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from typing import Any, Dict + +from PyCGraph import CStatus + +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.operators.index_op.gremlin_example_index_query import GremlinExampleIndexQuery +from hugegraph_llm.models.embeddings.init_embedding import Embeddings + + +class GremlinExampleIndexQueryNode(BaseNode): + operator: GremlinExampleIndexQuery + + def node_init(self): + # Build operator (index lazy-loading handled in operator) + embedding = Embeddings().get_embedding() + example_num = getattr(self.wk_input, "example_num", None) + if not isinstance(example_num, int): + example_num = 2 + # Clamp to [0, 10] + example_num = max(0, min(10, example_num)) + self.operator = GremlinExampleIndexQuery(embedding=embedding, num_examples=example_num) + return CStatus() + + def operator_schedule(self, data_json: Dict[str, Any]): + # Ensure query is present in context; degrade gracefully if empty + query = getattr(self.wk_input, "query", "") or "" + data_json["query"] = query + if not query: + data_json["match_result"] = [] + return data_json + # Operator.run writes match_result into context + return self.operator.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py index a28b41346..7df2e68e7 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py @@ -61,16 +61,12 @@ def node_init(self): # few_shot_schema: already parsed dict or raw JSON string few_shot_schema = {} - fss_src = ( - self.wk_input.few_shot_schema if self.wk_input.few_shot_schema else None - ) + fss_src = self.wk_input.few_shot_schema if self.wk_input.few_shot_schema else None if fss_src: try: few_shot_schema = json.loads(fss_src) except json.JSONDecodeError as e: - return CStatus( - -1, f"Few Shot Schema is not in a valid JSON format: {e}" - ) + return CStatus(-1, f"Few Shot Schema is not in a valid 
JSON format: {e}") _context_payload = { "raw_texts": raw_texts, diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/text2gremlin.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/text2gremlin.py new file mode 100644 index 000000000..ffbafbaf4 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/text2gremlin.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import json +from typing import Any, Dict, Optional + +from PyCGraph import CStatus + +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.operators.llm_op.gremlin_generate import GremlinGenerateSynthesize +from hugegraph_llm.models.llms.init_llm import LLMs +from hugegraph_llm.config import prompt as prompt_cfg + + +def _stable_schema_string(state_json: Dict[str, Any]) -> str: + if "simple_schema" in state_json and state_json["simple_schema"] is not None: + return json.dumps(state_json["simple_schema"], ensure_ascii=False, sort_keys=True) + if "schema" in state_json and state_json["schema"] is not None: + return json.dumps(state_json["schema"], ensure_ascii=False, sort_keys=True) + return "" + + +class Text2GremlinNode(BaseNode): + operator: GremlinGenerateSynthesize + + def node_init(self): + # Select LLM + llm = LLMs().get_text2gql_llm() + # Serialize schema deterministically + state_json = self.context.to_json() + schema_str = _stable_schema_string(state_json) + # Prompt fallback + gremlin_prompt: Optional[str] = getattr(self.wk_input, "gremlin_prompt", None) + if gremlin_prompt is None or not str(gremlin_prompt).strip(): + gremlin_prompt = prompt_cfg.gremlin_generate_prompt + # Keep vertices/properties empty for now + self.operator = GremlinGenerateSynthesize( + llm=llm, + schema=schema_str, + vertices=None, + gremlin_prompt=gremlin_prompt, + ) + return CStatus() + + def operator_schedule(self, data_json: Dict[str, Any]): + # Ensure query exists in context; return empty if not provided + query = getattr(self.wk_input, "query", "") or "" + data_json["query"] = query + if not query: + data_json["result"] = "" + data_json["raw_result"] = "" + return data_json + # increase call count for observability + prev = data_json.get("call_count", 0) or 0 + data_json["call_count"] = prev + 1 + return self.operator.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py 
b/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py index c1c742032..fc729c11e 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/check_schema.py @@ -59,12 +59,8 @@ def _validate_schema(self, schema: Dict[str, Any]) -> None: check_type(schema, dict, "Input data is not a dictionary.") if "vertexlabels" not in schema or "edgelabels" not in schema: log_and_raise("Input data does not contain 'vertexlabels' or 'edgelabels'.") - check_type( - schema["vertexlabels"], list, "'vertexlabels' in input data is not a list." - ) - check_type( - schema["edgelabels"], list, "'edgelabels' in input data is not a list." - ) + check_type(schema["vertexlabels"], list, "'vertexlabels' in input data is not a list.") + check_type(schema["edgelabels"], list, "'edgelabels' in input data is not a list.") def _process_property_labels(self, schema: Dict[str, Any]) -> (list, set): property_labels = schema.get("propertykeys", []) @@ -82,19 +78,13 @@ def _process_vertex_labels( for vertex_label in schema["vertexlabels"]: self._validate_vertex_label(vertex_label) properties = vertex_label["properties"] - primary_keys = self._process_keys( - vertex_label, "primary_keys", properties[:1] - ) + primary_keys = self._process_keys(vertex_label, "primary_keys", properties[:1]) if len(primary_keys) == 0: log_and_raise(f"'primary_keys' of {vertex_label['name']} is empty.") vertex_label["primary_keys"] = primary_keys - nullable_keys = self._process_keys( - vertex_label, "nullable_keys", properties[1:] - ) + nullable_keys = self._process_keys(vertex_label, "nullable_keys", properties[1:]) vertex_label["nullable_keys"] = nullable_keys - self._add_missing_properties( - properties, property_labels, property_label_set - ) + self._add_missing_properties(properties, property_labels, property_label_set) def _process_edge_labels( self, schema: Dict[str, Any], property_labels: list, property_label_set: set @@ 
-102,17 +92,13 @@ def _process_edge_labels( for edge_label in schema["edgelabels"]: self._validate_edge_label(edge_label) properties = edge_label.get("properties", []) - self._add_missing_properties( - properties, property_labels, property_label_set - ) + self._add_missing_properties(properties, property_labels, property_label_set) def _validate_vertex_label(self, vertex_label: Dict[str, Any]) -> None: check_type(vertex_label, dict, "VertexLabel in input data is not a dictionary.") if "name" not in vertex_label: log_and_raise("VertexLabel in input data does not contain 'name'.") - check_type( - vertex_label["name"], str, "'name' in vertex_label is not of correct type." - ) + check_type(vertex_label["name"], str, "'name' in vertex_label is not of correct type.") if "properties" not in vertex_label: log_and_raise("VertexLabel in input data does not contain 'properties'.") check_type( @@ -133,9 +119,7 @@ def _validate_edge_label(self, edge_label: Dict[str, Any]) -> None: log_and_raise( "EdgeLabel in input data does not contain 'name', 'source_label', 'target_label'." ) - check_type( - edge_label["name"], str, "'name' in edge_label is not of correct type." - ) + check_type(edge_label["name"], str, "'name' in edge_label is not of correct type.") check_type( edge_label["source_label"], str, @@ -147,13 +131,9 @@ def _validate_edge_label(self, edge_label: Dict[str, Any]) -> None: "'target_label' in edge_label is not of correct type.", ) - def _process_keys( - self, label: Dict[str, Any], key_type: str, default_keys: list - ) -> list: + def _process_keys(self, label: Dict[str, Any], key_type: str, default_keys: list) -> list: keys = label.get(key_type, default_keys) - check_type( - keys, list, f"'{key_type}' in {label['name']} is not of correct type." 
- ) + check_type(keys, list, f"'{key_type}' in {label['name']} is not of correct type.") new_keys = [key for key in keys if key in label["properties"]] return new_keys diff --git a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py index 910de20d5..dc5b15e00 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py @@ -126,15 +126,20 @@ def _rerank_with_vertex_degree( reranker = Rerankers().get_reranker() try: vertex_rerank_res = [ - reranker.get_rerank_lists(query, vertex_degree) + [""] for vertex_degree in vertex_degree_list + reranker.get_rerank_lists(query, vertex_degree) + [""] + for vertex_degree in vertex_degree_list ] except requests.exceptions.RequestException as e: - log.warning("Online reranker fails, automatically switches to local bleu method: %s", e) + log.warning( + "Online reranker fails, automatically switches to local bleu method: %s", e + ) self.method = "bleu" self.switch_to_bleu = True if self.method == "bleu": - vertex_rerank_res = [_bleu_rerank(query, vertex_degree) + [""] for vertex_degree in vertex_degree_list] + vertex_rerank_res = [ + _bleu_rerank(query, vertex_degree) + [""] for vertex_degree in vertex_degree_list + ] depth = len(vertex_degree_list) for result in results: @@ -144,7 +149,9 @@ def _rerank_with_vertex_degree( knowledge_with_degree[result] += [""] * (depth - len(knowledge_with_degree[result])) def sort_key(res: str) -> Tuple[int, ...]: - return tuple(vertex_rerank_res[i].index(knowledge_with_degree[res][i]) for i in range(depth)) + return tuple( + vertex_rerank_res[i].index(knowledge_with_degree[res][i]) for i in range(depth) + ) sorted_results = sorted(results, key=sort_key) return sorted_results[:topn] diff --git a/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py 
b/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py index a873e19ad..6771a9aab 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py @@ -56,7 +56,8 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: context["keywords"] = keywords from hugegraph_llm.utils.log import log - log.info("KEYWORDS: %s", context['keywords']) + + log.info("KEYWORDS: %s", context["keywords"]) return context def _filter_keywords( diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py index be0ac0ca6..58848f827 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py @@ -184,7 +184,7 @@ def merge_dedup_rerank( method=rerank_method, near_neighbor_first=near_neighbor_first, custom_related_information=custom_related_information, - topk_return_results=topk_return_results + topk_return_results=topk_return_results, ) ) return self @@ -238,7 +238,7 @@ def run(self, **kwargs) -> Dict[str, Any]: """ if len(self._operators) == 0: self.extract_keywords().query_graphdb( - max_graph_items=kwargs.get('max_graph_items') + max_graph_items=kwargs.get("max_graph_items") ).synthesize_answer() context = kwargs diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py index 9eec04f7f..52626b72b 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py @@ -50,9 +50,7 @@ def run(self, data: dict) -> Dict[str, Any]: if not schema: # TODO: ensure the function works correctly (update the logic later) self.schema_free_mode(data.get("triples", [])) - log.warning( - "Using schema_free mode, could try 
schema_define mode for better effect!" - ) + log.warning("Using schema_free mode, could try schema_define mode for better effect!") else: self.init_schema_if_need(schema) self.load_into_graph(vertices, edges, schema) @@ -68,9 +66,7 @@ def _set_default_property(self, key, input_properties, property_label_map): # list or set default_value = [] input_properties[key] = default_value - log.warning( - "Property '%s' missing in vertex, set to '%s' for now", key, default_value - ) + log.warning("Property '%s' missing in vertex, set to '%s' for now", key, default_value) def _handle_graph_creation(self, func, *args, **kwargs): try: @@ -84,13 +80,9 @@ def _handle_graph_creation(self, func, *args, **kwargs): def load_into_graph(self, vertices, edges, schema): # pylint: disable=too-many-statements # pylint: disable=R0912 (too-many-branches) - vertex_label_map = { - v_label["name"]: v_label for v_label in schema["vertexlabels"] - } + vertex_label_map = {v_label["name"]: v_label for v_label in schema["vertexlabels"]} edge_label_map = {e_label["name"]: e_label for e_label in schema["edgelabels"]} - property_label_map = { - p_label["name"]: p_label for p_label in schema["propertykeys"] - } + property_label_map = {p_label["name"]: p_label for p_label in schema["propertykeys"]} for vertex in vertices: input_label = vertex["label"] @@ -106,9 +98,7 @@ def load_into_graph(self, vertices, edges, schema): # pylint: disable=too-many- vertex_label = vertex_label_map[input_label] primary_keys = vertex_label["primary_keys"] nullable_keys = vertex_label.get("nullable_keys", []) - non_null_keys = [ - key for key in vertex_label["properties"] if key not in nullable_keys - ] + non_null_keys = [key for key in vertex_label["properties"] if key not in nullable_keys] has_problem = False # 2. Handle primary-keys mode vertex @@ -140,9 +130,7 @@ def load_into_graph(self, vertices, edges, schema): # pylint: disable=too-many- # 3. 
Ensure all non-nullable props are set for key in non_null_keys: if key not in input_properties: - self._set_default_property( - key, input_properties, property_label_map - ) + self._set_default_property(key, input_properties, property_label_map) # 4. Check all data type value is right for key, value in input_properties.items(): @@ -179,9 +167,7 @@ def load_into_graph(self, vertices, edges, schema): # pylint: disable=too-many- continue # TODO: we could try batch add edges first, setback to single-mode if failed - self._handle_graph_creation( - self.client.graph().addEdge, label, start, end, properties - ) + self._handle_graph_creation(self.client.graph().addEdge, label, start, end, properties) def init_schema_if_need(self, schema: dict): properties = schema["propertykeys"] @@ -205,20 +191,18 @@ def init_schema_if_need(self, schema: dict): source_vertex_label = edge["source_label"] target_vertex_label = edge["target_label"] properties = edge["properties"] - self.schema.edgeLabel(edge_label).sourceLabel( - source_vertex_label - ).targetLabel(target_vertex_label).properties(*properties).nullableKeys( - *properties - ).ifNotExist().create() + self.schema.edgeLabel(edge_label).sourceLabel(source_vertex_label).targetLabel( + target_vertex_label + ).properties(*properties).nullableKeys(*properties).ifNotExist().create() def schema_free_mode(self, data): self.schema.propertyKey("name").asText().ifNotExist().create() self.schema.vertexLabel("vertex").useCustomizeStringId().properties( "name" ).ifNotExist().create() - self.schema.edgeLabel("edge").sourceLabel("vertex").targetLabel( - "vertex" - ).properties("name").ifNotExist().create() + self.schema.edgeLabel("edge").sourceLabel("vertex").targetLabel("vertex").properties( + "name" + ).ifNotExist().create() self.schema.indexLabel("vertexByName").onV("vertex").by( "name" @@ -278,9 +262,7 @@ def _set_property_data_type(self, property_key, data_type): log.warning("UUID type is not supported, use text instead") 
property_key.asText() else: - log.error( - "Unknown data type %s for property_key %s", data_type, property_key - ) + log.error("Unknown data type %s for property_key %s", data_type, property_key) def _set_property_cardinality(self, property_key, cardinality): if cardinality == PropertyCardinality.SINGLE: @@ -290,13 +272,9 @@ def _set_property_cardinality(self, property_key, cardinality): elif cardinality == PropertyCardinality.SET: property_key.valueSet() else: - log.error( - "Unknown cardinality %s for property_key %s", cardinality, property_key - ) + log.error("Unknown cardinality %s for property_key %s", cardinality, property_key) - def _check_property_data_type( - self, data_type: str, cardinality: str, value - ) -> bool: + def _check_property_data_type(self, data_type: str, cardinality: str, value) -> bool: if cardinality in ( PropertyCardinality.LIST.value, PropertyCardinality.SET.value, @@ -326,9 +304,7 @@ def _check_single_data_type(self, data_type: str, value) -> bool: if data_type in (PropertyDataType.TEXT.value, PropertyDataType.UUID.value): return isinstance(value, str) # TODO: check ok below - if ( - data_type == PropertyDataType.DATE.value - ): # the format should be "yyyy-MM-dd" + if data_type == PropertyDataType.DATE.value: # the format should be "yyyy-MM-dd" import re return isinstance(value, str) and re.match(r"^\d{4}-\d{2}-\d{2}$", value) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py index 6012b7534..bcff5f07b 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py @@ -196,8 +196,8 @@ def _subgraph_query(self, context: Dict[str, Any]) -> Dict[str, Any]: log.debug("Kneighbor gremlin query: %s", gremlin_query) paths.extend(self._client.gremlin().exec(gremlin=gremlin_query)["data"]) - graph_chain_knowledge, 
vertex_degree_list, knowledge_with_degree = self._format_graph_query_result( - query_paths=paths + graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = ( + self._format_graph_query_result(query_paths=paths) ) # TODO: we may need to optimize the logic here with global deduplication (may lack some single vertex) @@ -220,17 +220,21 @@ def _subgraph_query(self, context: Dict[str, Any]) -> Dict[str, Any]: max_deep=self._max_deep, max_items=self._max_items, ) - log.warning("Unable to find vid, downgraded to property query, please confirm if it meets expectation.") + log.warning( + "Unable to find vid, downgraded to property query, please confirm if it meets expectation." + ) paths: List[Any] = self._client.gremlin().exec(gremlin=gremlin_query)["data"] - graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = self._format_graph_query_result( - query_paths=paths + graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = ( + self._format_graph_query_result(query_paths=paths) ) context["graph_result"] = list(graph_chain_knowledge) if context["graph_result"]: context["graph_result_flag"] = 0 - context["vertex_degree_list"] = [list(vertex_degree) for vertex_degree in vertex_degree_list] + context["vertex_degree_list"] = [ + list(vertex_degree) for vertex_degree in vertex_degree_list + ] context["knowledge_with_degree"] = knowledge_with_degree context["graph_context_head"] = ( f"The following are graph knowledge in {self._max_deep} depth, e.g:\n" @@ -272,7 +276,9 @@ def _format_graph_from_vertex(self, query_result: List[Any]) -> Set[str]: knowledge.add(node_str) return knowledge - def _format_graph_query_result(self, query_paths) -> Tuple[Set[str], List[Set[str]], Dict[str, List[str]]]: + def _format_graph_query_result( + self, query_paths + ) -> Tuple[Set[str], List[Set[str]], Dict[str, List[str]]]: use_id_to_match = self._prop_to_match is None subgraph = set() subgraph_with_degree = {} @@ -282,7 +288,9 @@ def _format_graph_query_result(self, 
query_paths) -> Tuple[Set[str], List[Set[st for path in query_paths: # 1. Process each path - path_str, vertex_with_degree = self._process_path(path, use_id_to_match, v_cache, e_cache) + path_str, vertex_with_degree = self._process_path( + path, use_id_to_match, v_cache, e_cache + ) subgraph.add(path_str) subgraph_with_degree[path_str] = vertex_with_degree # 2. Update vertex degree list @@ -291,7 +299,11 @@ def _format_graph_query_result(self, query_paths) -> Tuple[Set[str], List[Set[st return subgraph, vertex_degree_list, subgraph_with_degree def _process_path( - self, path: Any, use_id_to_match: bool, v_cache: Set[str], e_cache: Set[Tuple[str, str, str]] + self, + path: Any, + use_id_to_match: bool, + v_cache: Set[str], + e_cache: Set[Tuple[str, str, str]], ) -> Tuple[str, List[str]]: flat_rel = "" raw_flat_rel = path["objects"] @@ -306,7 +318,14 @@ def _process_path( if i % 2 == 0: # Process each vertex flat_rel, prior_edge_str_len, depth = self._process_vertex( - item, flat_rel, node_cache, prior_edge_str_len, depth, nodes_with_degree, use_id_to_match, v_cache + item, + flat_rel, + node_cache, + prior_edge_str_len, + depth, + nodes_with_degree, + use_id_to_match, + v_cache, ) else: # Process each edge @@ -333,7 +352,9 @@ def _process_vertex( return flat_rel, prior_edge_str_len, depth node_cache.add(matched_str) - props_str = ", ".join(f"{k}: {self._limit_property_query(v, 'v')}" for k, v in item["props"].items() if v) + props_str = ", ".join( + f"{k}: {self._limit_property_query(v, 'v')}" for k, v in item["props"].items() if v + ) # TODO: we may remove label id or replace with label name if matched_str in v_cache: @@ -356,10 +377,14 @@ def _process_edge( use_id_to_match: bool, e_cache: Set[Tuple[str, str, str]], ) -> Tuple[str, int]: - props_str = ", ".join(f"{k}: {self._limit_property_query(v, 'e')}" for k, v in item["props"].items() if v) + props_str = ", ".join( + f"{k}: {self._limit_property_query(v, 'e')}" for k, v in item["props"].items() if v + ) 
props_str = f"{{{props_str}}}" if props_str else "" prev_matched_str = ( - raw_flat_rel[i - 1]["id"] if use_id_to_match else (raw_flat_rel)[i - 1]["props"][self._prop_to_match] + raw_flat_rel[i - 1]["id"] + if use_id_to_match + else (raw_flat_rel)[i - 1]["props"][self._prop_to_match] ) edge_key = (item["inV"], item["label"], item["outV"]) @@ -369,12 +394,16 @@ def _process_edge( else: edge_label = item["label"] - edge_str = f"--[{edge_label}]-->" if item["outV"] == prev_matched_str else f"<--[{edge_label}]--" + edge_str = ( + f"--[{edge_label}]-->" if item["outV"] == prev_matched_str else f"<--[{edge_label}]--" + ) path_str += edge_str prior_edge_str_len = len(edge_str) return path_str, prior_edge_str_len - def _update_vertex_degree_list(self, vertex_degree_list: List[Set[str]], nodes_with_degree: List[str]) -> None: + def _update_vertex_degree_list( + self, vertex_degree_list: List[Set[str]], nodes_with_degree: List[str] + ) -> None: for depth, node_str in enumerate(nodes_with_degree): if depth >= len(vertex_degree_list): vertex_degree_list.append(set()) @@ -384,8 +413,8 @@ def _extract_labels_from_schema(self) -> Tuple[List[str], List[str]]: schema = self._get_graph_schema() vertex_props_str, edge_props_str = schema.split("\n")[:2] # TODO: rename to vertex (also need update in the schema) - vertex_props_str = vertex_props_str[len("Vertex properties: "):].strip("[").strip("]") - edge_props_str = edge_props_str[len("Edge properties: "):].strip("[").strip("]") + vertex_props_str = vertex_props_str[len("Vertex properties: ") :].strip("[").strip("]") + edge_props_str = edge_props_str[len("Edge properties: ") :].strip("[").strip("]") vertex_labels = self._extract_label_names(vertex_props_str) edge_labels = self._extract_label_names(edge_props_str) return vertex_labels, edge_labels diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py index c4e2124c3..90f1c00ea 
100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/schema_manager.py @@ -40,9 +40,7 @@ def simple_schema(self, schema: Dict[str, Any]) -> Dict[str, Any]: mini_schema["vertexlabels"] = [] for vertex in schema["vertexlabels"]: new_vertex = { - key: vertex[key] - for key in ["id", "name", "properties"] - if key in vertex + key: vertex[key] for key in ["id", "name", "properties"] if key in vertex } mini_schema["vertexlabels"].append(new_vertex) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py index 657baf68e..6d9f96214 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py @@ -23,17 +23,25 @@ from hugegraph_llm.config import resource_path, llm_settings, huge_settings from hugegraph_llm.indices.vector_index import VectorIndex from hugegraph_llm.models.embeddings.base import BaseEmbedding -from hugegraph_llm.utils.embedding_utils import get_embeddings_parallel, get_filename_prefix, get_index_folder_name +from hugegraph_llm.utils.embedding_utils import ( + get_embeddings_parallel, + get_filename_prefix, + get_index_folder_name, +) # FIXME: we need keep the logic same with build_semantic_index.py class BuildGremlinExampleIndex: def __init__(self, embedding: BaseEmbedding, examples: List[Dict[str, str]]): - self.folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) + self.folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) self.index_dir = str(os.path.join(resource_path, self.folder_name, "gremlin_examples")) self.examples = examples self.embedding = embedding - self.filename_prefix = get_filename_prefix(llm_settings.embedding_type, 
getattr(embedding, "model_name", None)) + self.filename_prefix = get_filename_prefix( + llm_settings.embedding_type, getattr(embedding, "model_name", None) + ) def run(self, context: Dict[str, Any]) -> Dict[str, Any]: # !: We have assumed that self.example is not empty diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py index 4b7c4e3d4..5689a59ac 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py @@ -24,15 +24,23 @@ from hugegraph_llm.indices.vector_index import VectorIndex from hugegraph_llm.models.embeddings.base import BaseEmbedding from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager -from hugegraph_llm.utils.embedding_utils import get_embeddings_parallel, get_filename_prefix, get_index_folder_name +from hugegraph_llm.utils.embedding_utils import ( + get_embeddings_parallel, + get_filename_prefix, + get_index_folder_name, +) from hugegraph_llm.utils.log import log class BuildSemanticIndex: def __init__(self, embedding: BaseEmbedding): - self.folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) + self.folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) self.index_dir = str(os.path.join(resource_path, self.folder_name, "graph_vids")) - self.filename_prefix = get_filename_prefix(llm_settings.embedding_type, getattr(embedding, "model_name", None)) + self.filename_prefix = get_filename_prefix( + llm_settings.embedding_type, getattr(embedding, "model_name", None) + ) self.vid_index = VectorIndex.from_index_file(self.index_dir, self.filename_prefix) self.embedding = embedding self.sm = SchemaManager(huge_settings.graph_name) @@ -42,27 +50,19 @@ def _extract_names(self, vertices: list[str]) -> list[str]: def run(self, context: Dict[str, 
Any]) -> Dict[str, Any]: vertexlabels = self.sm.schema.getSchema()["vertexlabels"] - all_pk_flag = all( - data.get("id_strategy") == "PRIMARY_KEY" for data in vertexlabels - ) + all_pk_flag = all(data.get("id_strategy") == "PRIMARY_KEY" for data in vertexlabels) past_vids = self.vid_index.properties # TODO: We should build vid vector index separately, especially when the vertices may be very large - present_vids = context[ - "vertices" - ] # Warning: data truncated by fetch_graph_data.py + present_vids = context["vertices"] # Warning: data truncated by fetch_graph_data.py removed_vids = set(past_vids) - set(present_vids) removed_num = self.vid_index.remove(removed_vids) added_vids = list(set(present_vids) - set(past_vids)) if added_vids: - vids_to_process = ( - self._extract_names(added_vids) if all_pk_flag else added_vids - ) - added_embeddings = asyncio.run( - get_embeddings_parallel(self.embedding, vids_to_process) - ) + vids_to_process = self._extract_names(added_vids) if all_pk_flag else added_vids + added_embeddings = asyncio.run(get_embeddings_parallel(self.embedding, vids_to_process)) log.info("Building vector index for %s vertices...", len(added_vids)) self.vid_index.add(added_embeddings, added_vids) self.vid_index.to_index_file(self.index_dir, self.filename_prefix) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py index 5cdad0316..f5fb823c5 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_vector_index.py @@ -41,9 +41,7 @@ def __init__(self, embedding: BaseEmbedding): self.filename_prefix = get_filename_prefix( llm_settings.embedding_type, getattr(self.embedding, "model_name", None) ) - self.vector_index = VectorIndex.from_index_file( - self.index_dir, self.filename_prefix - ) + self.vector_index = VectorIndex.from_index_file(self.index_dir, 
self.filename_prefix) def run(self, context: Dict[str, Any]) -> Dict[str, Any]: if "chunks" not in context: diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py index 96d1a3833..b680f2ca3 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/gremlin_example_index_query.py @@ -26,7 +26,11 @@ from hugegraph_llm.indices.vector_index import VectorIndex, INDEX_FILE_NAME, PROPERTIES_FILE_NAME from hugegraph_llm.models.embeddings.base import BaseEmbedding from hugegraph_llm.models.embeddings.init_embedding import Embeddings -from hugegraph_llm.utils.embedding_utils import get_embeddings_parallel, get_filename_prefix, get_index_folder_name +from hugegraph_llm.utils.embedding_utils import ( + get_embeddings_parallel, + get_filename_prefix, + get_index_folder_name, +) from hugegraph_llm.utils.log import log @@ -34,16 +38,25 @@ class GremlinExampleIndexQuery: def __init__(self, embedding: BaseEmbedding = None, num_examples: int = 1): self.embedding = embedding or Embeddings().get_embedding() self.num_examples = num_examples - self.folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) + self.folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) self.index_dir = str(os.path.join(resource_path, self.folder_name, "gremlin_examples")) - self.filename_prefix = get_filename_prefix(llm_settings.embedding_type, - getattr(self.embedding, "model_name", None)) + self.filename_prefix = get_filename_prefix( + llm_settings.embedding_type, getattr(self.embedding, "model_name", None) + ) self._ensure_index_exists() self.vector_index = VectorIndex.from_index_file(self.index_dir, self.filename_prefix) def _ensure_index_exists(self): - index_name = f"{self.filename_prefix}_{INDEX_FILE_NAME}" if 
self.filename_prefix else INDEX_FILE_NAME - props_name = f"{self.filename_prefix}_{PROPERTIES_FILE_NAME}" if self.filename_prefix else PROPERTIES_FILE_NAME + index_name = ( + f"{self.filename_prefix}_{INDEX_FILE_NAME}" if self.filename_prefix else INDEX_FILE_NAME + ) + props_name = ( + f"{self.filename_prefix}_{PROPERTIES_FILE_NAME}" + if self.filename_prefix + else PROPERTIES_FILE_NAME + ) if not ( os.path.exists(os.path.join(self.index_dir, index_name)) and os.path.exists(os.path.join(self.index_dir, props_name)) @@ -61,7 +74,9 @@ def _get_match_result(self, context: Dict[str, Any], query: str) -> List[Dict[st return self.vector_index.search(query_embedding, self.num_examples, dis_threshold=1.8) def _build_default_example_index(self): - properties = pd.read_csv(os.path.join(resource_path, "demo", "text2gremlin.csv")).to_dict(orient="records") + properties = pd.read_csv(os.path.join(resource_path, "demo", "text2gremlin.csv")).to_dict( + orient="records" + ) # TODO: reuse the logic in build_semantic_index.py (consider extract the batch-embedding method) queries = [row["query"] for row in properties] embeddings = asyncio.run(get_embeddings_parallel(self.embedding, queries)) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py index 8e195453d..3ac03246f 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py @@ -31,16 +31,20 @@ class SemanticIdQuery: ID_QUERY_TEMPL = "g.V({vids_str}).limit(8)" def __init__( - self, - embedding: BaseEmbedding, - by: Literal["query", "keywords"] = "keywords", - topk_per_query: int = 10, - topk_per_keyword: int = huge_settings.topk_per_keyword, - vector_dis_threshold: float = huge_settings.vector_dis_threshold, + self, + embedding: BaseEmbedding, + by: Literal["query", "keywords"] = "keywords", + topk_per_query: int = 10, + 
topk_per_keyword: int = huge_settings.topk_per_keyword, + vector_dis_threshold: float = huge_settings.vector_dis_threshold, ): - self.folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) + self.folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) self.index_dir = str(os.path.join(resource_path, self.folder_name, "graph_vids")) - self.filename_prefix = get_filename_prefix(llm_settings.embedding_type, getattr(embedding, "model_name", None)) + self.filename_prefix = get_filename_prefix( + llm_settings.embedding_type, getattr(embedding, "model_name", None) + ) self.vector_index = VectorIndex.from_index_file(self.index_dir, self.filename_prefix) self.embedding = embedding self.by = by @@ -65,7 +69,7 @@ def _exact_match_vids(self, keywords: List[str]) -> Tuple[List[str], List[str]]: vids_str = ",".join([f"'{vid}'" for vid in possible_vids]) resp = self._client.gremlin().exec(SemanticIdQuery.ID_QUERY_TEMPL.format(vids_str=vids_str)) - searched_vids = [v['id'] for v in resp['data']] + searched_vids = [v["id"] for v in resp["data"]] unsearched_keywords = set(keywords) for vid in searched_vids: @@ -79,10 +83,13 @@ def _fuzzy_match_vids(self, keywords: List[str]) -> List[str]: fuzzy_match_result = [] for keyword in keywords: keyword_vector = self.embedding.get_texts_embeddings([keyword])[0] - results = self.vector_index.search(keyword_vector, top_k=self.topk_per_keyword, - dis_threshold=float(self.vector_dis_threshold)) + results = self.vector_index.search( + keyword_vector, + top_k=self.topk_per_keyword, + dis_threshold=float(self.vector_dis_threshold), + ) if results: - fuzzy_match_result.extend(results[:self.topk_per_keyword]) + fuzzy_match_result.extend(results[: self.topk_per_keyword]) return fuzzy_match_result def run(self, context: Dict[str, Any]) -> Dict[str, Any]: @@ -92,7 +99,7 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: query_vector = 
self.embedding.get_texts_embeddings([query])[0] results = self.vector_index.search(query_vector, top_k=self.topk_per_query) if results: - graph_query_list.update(results[:self.topk_per_query]) + graph_query_list.update(results[: self.topk_per_query]) else: # by keywords keywords = context.get("keywords", []) if not keywords: diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py index e29f50a76..4ed616929 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/vector_index_query.py @@ -30,9 +30,13 @@ class VectorIndexQuery: def __init__(self, embedding: BaseEmbedding, topk: int = 3): self.embedding = embedding self.topk = topk - self.folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) + self.folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) self.index_dir = str(os.path.join(resource_path, self.folder_name, "chunks")) - self.filename_prefix = get_filename_prefix(llm_settings.embedding_type, getattr(embedding, "model_name", None)) + self.filename_prefix = get_filename_prefix( + llm_settings.embedding_type, getattr(embedding, "model_name", None) + ) self.vector_index = VectorIndex.from_index_file(self.index_dir, self.filename_prefix) def run(self, context: Dict[str, Any]) -> Dict[str, Any]: diff --git a/hugegraph-llm/src/hugegraph_llm/operators/kg_construction_task.py b/hugegraph-llm/src/hugegraph_llm/operators/kg_construction_task.py index 4348477f6..3b5c63103 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/kg_construction_task.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/kg_construction_task.py @@ -37,7 +37,12 @@ class KgBuilder: - def __init__(self, llm: BaseLLM, embedding: Optional[BaseEmbedding] = None, graph: Optional[PyHugeClient] = None): + def __init__( + self, + llm: BaseLLM, 
+ embedding: Optional[BaseEmbedding] = None, + graph: Optional[PyHugeClient] = None, + ): self.operators = [] self.llm = llm self.embedding = embedding @@ -69,7 +74,9 @@ def chunk_split( return self def extract_info( - self, example_prompt: Optional[str] = None, extract_type: Literal["triples", "property_graph"] = "triples" + self, + example_prompt: Optional[str] = None, + extract_type: Literal["triples", "property_graph"] = "triples", ): if extract_type == "triples": self.operators.append(InfoExtract(self.llm, example_prompt)) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py index 5c4ab5fd3..9138f9e9b 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py @@ -62,17 +62,26 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: context_head_str, context_tail_str = self.init_llm(context) if self._context_body is not None: - context_str = (f"{context_head_str}\n" - f"{self._context_body}\n" - f"{context_tail_str}".strip("\n")) + context_str = ( + f"{context_head_str}\n" f"{self._context_body}\n" f"{context_tail_str}".strip("\n") + ) - final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question) + final_prompt = self._prompt_template.format( + context_str=context_str, query_str=self._question + ) response = self._llm.generate(prompt=final_prompt) return {"answer": response} graph_result_context, vector_result_context = self.handle_vector_graph(context) - context = asyncio.run(self.async_generate(context, context_head_str, context_tail_str, - vector_result_context, graph_result_context)) + context = asyncio.run( + self.async_generate( + context, + context_head_str, + context_tail_str, + vector_result_context, + graph_result_context, + ) + ) return context def init_llm(self, context): @@ -95,7 +104,9 @@ def 
handle_vector_graph(self, context): vector_result_context = "No (vector)phrase related to the query." graph_result = context.get("graph_result") if graph_result: - graph_context_head = context.get("graph_context_head", "Knowledge from graphdb for the query:\n") + graph_context_head = context.get( + "graph_context_head", "Knowledge from graphdb for the query:\n" + ) graph_result_context = graph_context_head + "\n".join( f"{i + 1}. {res}" for i, res in enumerate(graph_result) ) @@ -108,11 +119,13 @@ async def run_streaming(self, context: Dict[str, Any]) -> AsyncGenerator[Dict[st context_head_str, context_tail_str = self.init_llm(context) if self._context_body is not None: - context_str = (f"{context_head_str}\n" - f"{self._context_body}\n" - f"{context_tail_str}".strip("\n")) + context_str = ( + f"{context_head_str}\n" f"{self._context_body}\n" f"{context_tail_str}".strip("\n") + ) - final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question) + final_prompt = self._prompt_template.format( + context_str=context_str, query_str=self._question + ) response = self._llm.generate(prompt=final_prompt) yield {"answer": response} return @@ -120,45 +133,60 @@ async def run_streaming(self, context: Dict[str, Any]) -> AsyncGenerator[Dict[st graph_result_context, vector_result_context = self.handle_vector_graph(context) async for context in self.async_streaming_generate( - context, - context_head_str, - context_tail_str, - vector_result_context, - graph_result_context + context, context_head_str, context_tail_str, vector_result_context, graph_result_context ): yield context - async def async_generate(self, context: Dict[str, Any], context_head_str: str, - context_tail_str: str, vector_result_context: str, - graph_result_context: str): + async def async_generate( + self, + context: Dict[str, Any], + context_head_str: str, + context_tail_str: str, + vector_result_context: str, + graph_result_context: str, + ): # async_tasks stores the async tasks 
for different answer types async_tasks = {} if self._raw_answer: final_prompt = self._question async_tasks["raw_task"] = asyncio.create_task(self._llm.agenerate(prompt=final_prompt)) if self._vector_only_answer: - context_str = (f"{context_head_str}\n" - f"{vector_result_context}\n" - f"{context_tail_str}".strip("\n")) + context_str = ( + f"{context_head_str}\n" + f"{vector_result_context}\n" + f"{context_tail_str}".strip("\n") + ) - final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question) - async_tasks["vector_only_task"] = asyncio.create_task(self._llm.agenerate(prompt=final_prompt)) + final_prompt = self._prompt_template.format( + context_str=context_str, query_str=self._question + ) + async_tasks["vector_only_task"] = asyncio.create_task( + self._llm.agenerate(prompt=final_prompt) + ) if self._graph_only_answer: - context_str = (f"{context_head_str}\n" - f"{graph_result_context}\n" - f"{context_tail_str}".strip("\n")) + context_str = ( + f"{context_head_str}\n" + f"{graph_result_context}\n" + f"{context_tail_str}".strip("\n") + ) - final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question) - async_tasks["graph_only_task"] = asyncio.create_task(self._llm.agenerate(prompt=final_prompt)) + final_prompt = self._prompt_template.format( + context_str=context_str, query_str=self._question + ) + async_tasks["graph_only_task"] = asyncio.create_task( + self._llm.agenerate(prompt=final_prompt) + ) if self._graph_vector_answer: context_body_str = f"{vector_result_context}\n{graph_result_context}" if context.get("graph_ratio", 0.5) < 0.5: context_body_str = f"{graph_result_context}\n{vector_result_context}" - context_str = (f"{context_head_str}\n" - f"{context_body_str}\n" - f"{context_tail_str}".strip("\n")) + context_str = ( + f"{context_head_str}\n" f"{context_body_str}\n" f"{context_tail_str}".strip("\n") + ) - final_prompt = self._prompt_template.format(context_str=context_str, 
query_str=self._question) + final_prompt = self._prompt_template.format( + context_str=context_str, query_str=self._question + ) async_tasks["graph_vector_task"] = asyncio.create_task( self._llm.agenerate(prompt=final_prompt) ) @@ -167,7 +195,7 @@ async def async_generate(self, context: Dict[str, Any], context_head_str: str, "raw_task": "raw_answer", "vector_only_task": "vector_only_answer", "graph_only_task": "graph_only_answer", - "graph_vector_task": "graph_vector_answer" + "graph_vector_task": "graph_vector_answer", } for task_key, context_key in async_tasks_mapping.items(): @@ -176,66 +204,95 @@ async def async_generate(self, context: Dict[str, Any], context_head_str: str, context[context_key] = response log.debug("Query Answer: %s", response) - ops = sum([self._raw_answer, self._vector_only_answer, self._graph_only_answer, self._graph_vector_answer]) - context['call_count'] = context.get('call_count', 0) + ops + ops = sum( + [ + self._raw_answer, + self._vector_only_answer, + self._graph_only_answer, + self._graph_vector_answer, + ] + ) + context["call_count"] = context.get("call_count", 0) + ops return context - async def async_streaming_generate(self, context: Dict[str, Any], context_head_str: str, - context_tail_str: str, vector_result_context: str, - graph_result_context: str) -> AsyncGenerator[Dict[str, Any], None]: + async def async_streaming_generate( + self, + context: Dict[str, Any], + context_head_str: str, + context_tail_str: str, + vector_result_context: str, + graph_result_context: str, + ) -> AsyncGenerator[Dict[str, Any], None]: # async_tasks stores the async tasks for different answer types async_generators = [] auto_id = 0 if self._raw_answer: final_prompt = self._question async_generators.append( - self.__llm_generate_with_meta_info(task_id=auto_id, target_key="raw_answer", prompt=final_prompt) + self.__llm_generate_with_meta_info( + task_id=auto_id, target_key="raw_answer", prompt=final_prompt + ) ) auto_id += 1 if self._vector_only_answer: 
- context_str = (f"{context_head_str}\n" - f"{vector_result_context}\n" - f"{context_tail_str}".strip("\n")) + context_str = ( + f"{context_head_str}\n" + f"{vector_result_context}\n" + f"{context_tail_str}".strip("\n") + ) - final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question) + final_prompt = self._prompt_template.format( + context_str=context_str, query_str=self._question + ) async_generators.append( self.__llm_generate_with_meta_info( - task_id=auto_id, - target_key="vector_only_answer", - prompt=final_prompt + task_id=auto_id, target_key="vector_only_answer", prompt=final_prompt ) ) auto_id += 1 if self._graph_only_answer: - context_str = (f"{context_head_str}\n" - f"{graph_result_context}\n" - f"{context_tail_str}".strip("\n")) + context_str = ( + f"{context_head_str}\n" + f"{graph_result_context}\n" + f"{context_tail_str}".strip("\n") + ) - final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question) + final_prompt = self._prompt_template.format( + context_str=context_str, query_str=self._question + ) async_generators.append( - self.__llm_generate_with_meta_info(task_id=auto_id, target_key="graph_only_answer", prompt=final_prompt) + self.__llm_generate_with_meta_info( + task_id=auto_id, target_key="graph_only_answer", prompt=final_prompt + ) ) auto_id += 1 if self._graph_vector_answer: context_body_str = f"{vector_result_context}\n{graph_result_context}" if context.get("graph_ratio", 0.5) < 0.5: context_body_str = f"{graph_result_context}\n{vector_result_context}" - context_str = (f"{context_head_str}\n" - f"{context_body_str}\n" - f"{context_tail_str}".strip("\n")) + context_str = ( + f"{context_head_str}\n" f"{context_body_str}\n" f"{context_tail_str}".strip("\n") + ) - final_prompt = self._prompt_template.format(context_str=context_str, query_str=self._question) + final_prompt = self._prompt_template.format( + context_str=context_str, query_str=self._question + ) 
async_generators.append( self.__llm_generate_with_meta_info( - task_id=auto_id, - target_key="graph_vector_answer", - prompt=final_prompt + task_id=auto_id, target_key="graph_vector_answer", prompt=final_prompt ) ) auto_id += 1 - ops = sum([self._raw_answer, self._vector_only_answer, self._graph_only_answer, self._graph_vector_answer]) - context['call_count'] = context.get('call_count', 0) + ops + ops = sum( + [ + self._raw_answer, + self._vector_only_answer, + self._graph_only_answer, + self._graph_vector_answer, + ] + ) + context["call_count"] = context.get("call_count", 0) + ops async_tasks = [asyncio.create_task(anext(gen)) for gen in async_generators] while True: diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py index 817065aa0..2ac2eafff 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/disambiguate_data.py @@ -53,7 +53,8 @@ def run(self, data: Dict) -> Dict[str, List[Any]]: extract_triples_by_regex(llm_output, data) print( f"LLM {self.__class__.__name__} input:{prompt} \n" - f" output: {llm_output} \n data: {data}") + f" output: {llm_output} \n data: {data}" + ) data["call_count"] = data.get("call_count", 0) + 1 return data diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py index 11f0f6022..650834300 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/gremlin_generate.py @@ -54,7 +54,8 @@ def _format_examples(self, examples: Optional[List[Dict[str, str]]]) -> Optional example_strings = [] for example in examples: example_strings.append( - f"- query: {example['query']}\n" f"- gremlin:\n```gremlin\n{example['gremlin']}\n```" + f"- query: {example['query']}\n" + f"- 
gremlin:\n```gremlin\n{example['gremlin']}\n```" ) return "\n\n".join(example_strings) @@ -89,11 +90,17 @@ async def async_generate(self, context: Dict[str, Any]): vertices=self._format_vertices(vertices=self.vertices), properties=self._format_properties(properties=None), ) - async_tasks["initialized_answer"] = asyncio.create_task(self.llm.agenerate(prompt=init_prompt)) + async_tasks["initialized_answer"] = asyncio.create_task( + self.llm.agenerate(prompt=init_prompt) + ) raw_response = await async_tasks["raw_answer"] initialized_response = await async_tasks["initialized_answer"] - log.debug("Text2Gremlin with tmpl prompt:\n %s,\n LLM Response: %s", init_prompt, initialized_response) + log.debug( + "Text2Gremlin with tmpl prompt:\n %s,\n LLM Response: %s", + init_prompt, + initialized_response, + ) context["result"] = self._extract_response(response=initialized_response) context["raw_result"] = self._extract_response(response=raw_response) @@ -123,7 +130,11 @@ def sync_generate(self, context: Dict[str, Any]): ) initialized_response = self.llm.generate(prompt=init_prompt) - log.debug("Text2Gremlin with tmpl prompt:\n %s,\n LLM Response: %s", init_prompt, initialized_response) + log.debug( + "Text2Gremlin with tmpl prompt:\n %s,\n LLM Response: %s", + init_prompt, + initialized_response, + ) context["result"] = self._extract_response(response=initialized_response) context["raw_result"] = self._extract_response(response=raw_response) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py index 571ffde51..8897e0fea 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/info_extract.py @@ -198,12 +198,8 @@ def valid(self, element_id: str, max_length: int = 256) -> bool: return True def _filter_long_id(self, graph) -> Dict[str, List[Any]]: - graph["vertices"] = [ - vertex for vertex in graph["vertices"] if 
self.valid(vertex["id"]) - ] + graph["vertices"] = [vertex for vertex in graph["vertices"] if self.valid(vertex["id"])] graph["edges"] = [ - edge - for edge in graph["edges"] - if self.valid(edge["start"]) and self.valid(edge["end"]) + edge for edge in graph["edges"] if self.valid(edge["start"]) and self.valid(edge["end"]) ] return graph diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py index 1e9ca652b..32ed9651e 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py @@ -151,6 +151,7 @@ def _extract_keywords_from_response( response: str, lowercase: bool = True, start_token: str = "", +<<<<<<< HEAD ) -> Dict[str, float]: results = {} @@ -181,4 +182,27 @@ def _extract_keywords_from_response( except (ValueError, AttributeError) as e: log.warning("Failed to parse item '%s': %s", item, e) continue +======= + ) -> Set[str]: + keywords = [] + # use re.escape(start_token) if start_token contains special chars like */&/^ etc. 
+ matches = re.findall(rf"{start_token}[^\n]+\n?", response) + + for match in matches: + match = match[len(start_token) :].strip() + keywords.extend( + k.lower() if lowercase else k + for k in re.split(r"[,,]+", match) + if len(k.strip()) > 1 + ) + + # if the keyword consists of multiple words, split into sub-words (removing stopwords) + results = set(keywords) + for token in keywords: + sub_tokens = re.findall(r"\w+", token) + if len(sub_tokens) > 1: + results.update( + w for w in sub_tokens if w not in NLTKHelper().stopwords(lang=self._language) + ) +>>>>>>> 78011d3 (Refactor: text2germlin with PCgraph framework (#50)) return results diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/prompt_generate.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/prompt_generate.py index 82326f000..058d1bce9 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/prompt_generate.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/prompt_generate.py @@ -52,11 +52,11 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: few_shot_example = self._load_few_shot_example(example_name) meta_prompt = prompt_tpl.generate_extract_prompt_template.format( - few_shot_text=few_shot_example.get('text', ''), - few_shot_prompt=few_shot_example.get('prompt', ''), + few_shot_text=few_shot_example.get("text", ""), + few_shot_prompt=few_shot_example.get("prompt", ""), user_text=source_text, user_scenario=scenario, - language=prompt_tpl.llm_settings.language + language=prompt_tpl.llm_settings.language, ) log.debug("Meta-prompt sent to LLM: %s", meta_prompt) generated_prompt = self.llm.generate(prompt=meta_prompt) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py index 79fb33b4f..565d79023 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/property_graph_extract.py 
@@ -67,9 +67,9 @@ def filter_item(schema, items) -> List[Dict[str, Any]]: item_type = item["type"] if item_type == "vertex": label = item["label"] - non_nullable_keys = set( - properties_map[item_type][label]["properties"] - ).difference(set(properties_map[item_type][label]["nullable_keys"])) + non_nullable_keys = set(properties_map[item_type][label]["properties"]).difference( + set(properties_map[item_type][label]["nullable_keys"]) + ) for key in non_nullable_keys: if key not in item["properties"]: item["properties"][key] = "NULL" @@ -82,9 +82,7 @@ def filter_item(schema, items) -> List[Dict[str, Any]]: class PropertyGraphExtract: - def __init__( - self, llm: BaseLLM, example_prompt: str = prompt.extract_graph_prompt - ) -> None: + def __init__(self, llm: BaseLLM, example_prompt: str = prompt.extract_graph_prompt) -> None: self.llm = llm self.example_prompt = example_prompt self.NECESSARY_ITEM_KEYS = {"label", "type", "properties"} # pylint: disable=invalid-name @@ -142,9 +140,7 @@ def _extract_and_filter_label(self, schema, text) -> List[Dict[str, Any]]: and "vertices" in property_graph and "edges" in property_graph ): - log.critical( - "Invalid property graph format; expecting 'vertices' and 'edges'." 
- ) + log.critical("Invalid property graph format; expecting 'vertices' and 'edges'.") return items # Create sets for valid vertex and edge labels based on the schema @@ -154,9 +150,7 @@ def _extract_and_filter_label(self, schema, text) -> List[Dict[str, Any]]: def process_items(item_list, valid_labels, item_type): for item in item_list: if not isinstance(item, dict): - log.warning( - "Invalid property graph item type '%s'.", type(item) - ) + log.warning("Invalid property graph item type '%s'.", type(item)) continue if not self.NECESSARY_ITEM_KEYS.issubset(item.keys()): log.warning("Invalid item keys '%s'.", item.keys()) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/schema_build.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/schema_build.py index 53587381a..928948413 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/schema_build.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/schema_build.py @@ -34,7 +34,9 @@ def __init__( ): self.llm = llm or LLMs().get_chat_llm() # TODO: use a basic format for it - self.schema_prompt = schema_prompt or """ + self.schema_prompt = ( + schema_prompt + or """ You are a Graph Schema Generator for Apache HugeGraph. Based on the following three parts of content, output a Schema JSON that complies with HugeGraph specifications: @@ -53,6 +55,7 @@ def __init__( - Ensure the schema follows HugeGraph specifications - Do not include comments or extra fields. 
""" + ) def _format_raw_texts(self, raw_texts: List[str]) -> str: return "\n".join([f"- {text}" for text in raw_texts]) @@ -86,18 +89,15 @@ def build_prompt( self, raw_texts: List[str], query_examples: List[Dict[str, str]], - few_shot_schema: Dict[str, Any] + few_shot_schema: Dict[str, Any], ) -> str: return self.schema_prompt.format( raw_texts=self._format_raw_texts(raw_texts), query_examples=self._format_query_examples(query_examples), - few_shot_schema=self._format_few_shot_schema(few_shot_schema) + few_shot_schema=self._format_few_shot_schema(few_shot_schema), ) - def run( - self, - context: Dict[str, Any] - ) -> Dict[str, Any]: + def run(self, context: Dict[str, Any]) -> Dict[str, Any]: """Generate schema from context containing raw_texts, query_examples and few_shot_schema. Args: diff --git a/hugegraph-llm/src/hugegraph_llm/state/ai_state.py b/hugegraph-llm/src/hugegraph_llm/state/ai_state.py index 6d3418c00..f941098b1 100644 --- a/hugegraph-llm/src/hugegraph_llm/state/ai_state.py +++ b/hugegraph-llm/src/hugegraph_llm/state/ai_state.py @@ -33,6 +33,11 @@ class WkFlowInput(GParam): source_text: str = None # Original text scenario: str = None # Scenario description example_name: str = None # Example name + # Fields for Text2Gremlin + query: str = None + example_num: int = None + gremlin_prompt: str = None + requested_outputs: Optional[List[str]] = None def reset(self, _: CStatus) -> None: self.texts = None @@ -49,6 +54,11 @@ def reset(self, _: CStatus) -> None: self.source_text = None self.scenario = None self.example_name = None + # Text2Gremlin related configuration + self.query = None + self.example_num = None + self.gremlin_prompt = None + self.requested_outputs = None class WkFlowState(GParam): @@ -66,6 +76,12 @@ class WkFlowState(GParam): keywords_embeddings = None generated_extract_prompt: Optional[str] = None + # Fields for Text2Gremlin results + match_result: Optional[List[dict]] = None + result: Optional[str] = None + raw_result: Optional[str] = None 
+ template_exec_res: Optional[Any] = None + raw_exec_res: Optional[Any] = None def setup(self): self.schema = None @@ -74,7 +90,7 @@ def setup(self): self.edges = None self.vertices = None self.triples = None - self.call_count = None + self.call_count = 0 self.keywords = None self.vector_result = None @@ -82,6 +98,12 @@ def setup(self): self.keywords_embeddings = None self.generated_extract_prompt = None + # Text2Gremlin results reset + self.match_result = [] + self.result = "" + self.raw_result = "" + self.template_exec_res = "" + self.raw_exec_res = "" return CStatus() @@ -94,11 +116,7 @@ def to_json(self): dict: A dictionary containing non-None instance members and their serialized values. """ # Only export instance attributes (excluding methods and class attributes) whose values are not None - return { - k: v - for k, v in self.__dict__.items() - if not k.startswith("_") and v is not None - } + return {k: v for k, v in self.__dict__.items() if not k.startswith("_") and v is not None} # Implement a method that assigns keys from data_json as WkFlowState member variables def assign_from_json(self, data_json: dict): diff --git a/hugegraph-llm/src/hugegraph_llm/utils/anchor.py b/hugegraph-llm/src/hugegraph_llm/utils/anchor.py index d5f687a94..4542a7fd9 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/anchor.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/anchor.py @@ -15,16 +15,17 @@ from pathlib import Path + def get_project_root() -> Path: """ Returns the Path object of the project root directory. - - The function searches for common project root indicators like pyproject.toml + + The function searches for common project root indicators like pyproject.toml or .git directory by traversing up the directory tree from the current file location. 
- + Returns: Path: The absolute path to the project root directory - + Raises: RuntimeError: If no project root indicators could be found """ diff --git a/hugegraph-llm/src/hugegraph_llm/utils/decorators.py b/hugegraph-llm/src/hugegraph_llm/utils/decorators.py index b07de6f4b..2914c4b28 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/decorators.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/decorators.py @@ -109,6 +109,7 @@ def sync_wrapper(*args: Any, **kwargs: Any) -> Any: def with_task_id(func: Callable) -> Callable: def wrapper(*args: Any, **kwargs: Any) -> Any: import uuid + task_id = f"task_{str(uuid.uuid4())[:8]}" log.debug("New task created with id: %s", task_id) diff --git a/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py index 55e50eadd..b2f485cea 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/embedding_utils.py @@ -24,7 +24,9 @@ from hugegraph_llm.models.embeddings.base import BaseEmbedding -async def _get_batch_with_progress(embedding: BaseEmbedding, batch: list[str], pbar: tqdm) -> list[Any]: +async def _get_batch_with_progress( + embedding: BaseEmbedding, batch: list[str], pbar: tqdm +) -> list[Any]: result = await embedding.async_get_texts_embeddings(batch) pbar.update(1) return result @@ -58,10 +60,7 @@ async def get_embeddings_parallel(embedding: BaseEmbedding, vids: list[str]) -> embeddings = [] with tqdm(total=len(vid_batches)) as pbar: # Create tasks for each batch with progress bar updates - tasks = [ - _get_batch_with_progress(embedding, batch, pbar) - for batch in vid_batches - ] + tasks = [_get_batch_with_progress(embedding, batch, pbar) for batch in vid_batches] # Use asyncio.gather() to preserve order batch_results = await asyncio.gather(*tasks) diff --git a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py index ccace69f2..7b870033a 
100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py @@ -45,13 +45,9 @@ def get_graph_index_info(): def get_graph_index_info_old(): - builder = KgBuilder( - LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client() - ) + builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) graph_summary_info = builder.fetch_graph_data().run() - folder_name = get_index_folder_name( - huge_settings.graph_name, huge_settings.graph_space - ) + folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) index_dir = str(os.path.join(resource_path, folder_name, "graph_vids")) filename_prefix = get_filename_prefix( llm_settings.embedding_type, getattr(builder.embedding, "model_name", None) @@ -66,16 +62,12 @@ def get_graph_index_info_old(): def clean_all_graph_index(): - folder_name = get_index_folder_name( - huge_settings.graph_name, huge_settings.graph_space - ) + folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) filename_prefix = get_filename_prefix( llm_settings.embedding_type, getattr(Embeddings().get_embedding(), "model_name", None), ) - VectorIndex.clean( - str(os.path.join(resource_path, folder_name, "graph_vids")), filename_prefix - ) + VectorIndex.clean(str(os.path.join(resource_path, folder_name, "graph_vids")), filename_prefix) VectorIndex.clean( str(os.path.join(resource_path, folder_name, "gremlin_examples")), filename_prefix, @@ -107,18 +99,14 @@ def parse_schema(schema: str, builder: KgBuilder) -> Optional[str]: def extract_graph_origin(input_file, input_text, schema, example_prompt) -> str: texts = read_documents(input_file, input_text) - builder = KgBuilder( - LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client() - ) + builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) if not schema: return "ERROR: please input with 
correct schema/format." error_message = parse_schema(schema, builder) if error_message: return error_message - builder.chunk_split(texts, "document", "zh").extract_info( - example_prompt, "property_graph" - ) + builder.chunk_split(texts, "document", "zh").extract_info(example_prompt, "property_graph") try: context = builder.run() @@ -168,9 +156,7 @@ def update_vid_embedding(): def update_vid_embedding_old(): - builder = KgBuilder( - LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client() - ) + builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) builder.fetch_graph_data().build_vertex_id_semantic_index() log.debug("Operators: %s", builder.operators) try: @@ -199,9 +185,7 @@ def import_graph_data_old(data: str, schema: str) -> Union[str, Dict[str, Any]]: try: data_json = json.loads(data.strip()) log.debug("Import graph data: %s", data) - builder = KgBuilder( - LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client() - ) + builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) if schema: error_message = parse_schema(schema, builder) if error_message: @@ -222,9 +206,7 @@ def import_graph_data_old(data: str, schema: str) -> Union[str, Dict[str, Any]]: def build_schema(input_text, query_example, few_shot): scheduler = SchedulerSingleton.get_instance() try: - return scheduler.schedule_flow( - "build_schema", input_text, query_example, few_shot - ) + return scheduler.schedule_flow("build_schema", input_text, query_example, few_shot) except (TypeError, ValueError) as e: raise gr.Error(f"Schema generation failed: {e}") @@ -257,9 +239,7 @@ def build_schema_old(input_text, query_example, few_shot): except json.JSONDecodeError as e: raise gr.Error(f"Query Examples is not in a valid JSON format: {e}") from e - builder = KgBuilder( - LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client() - ) + builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), 
get_hg_client()) try: schema = builder.build_schema().run(context) except Exception as e: diff --git a/hugegraph-llm/src/hugegraph_llm/utils/hugegraph_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/hugegraph_utils.py index 1d02b45d3..147c0074c 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/hugegraph_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/hugegraph_utils.py @@ -53,7 +53,9 @@ def init_hg_test_data(): schema = client.schema() schema.propertyKey("name").asText().ifNotExist().create() schema.propertyKey("birthDate").asText().ifNotExist().create() - schema.vertexLabel("Person").properties("name", "birthDate").useCustomizeStringId().ifNotExist().create() + schema.vertexLabel("Person").properties( + "name", "birthDate" + ).useCustomizeStringId().ifNotExist().create() schema.vertexLabel("Movie").properties("name").useCustomizeStringId().ifNotExist().create() schema.edgeLabel("ActedIn").sourceLabel("Person").targetLabel("Movie").ifNotExist().create() @@ -110,13 +112,13 @@ def backup_data(): files = { "vertices.json": f"g.V().limit({MAX_VERTICES})" - f".aggregate('vertices').count().as('count').select('count','vertices')", + f".aggregate('vertices').count().as('count').select('count','vertices')", "edges.json": f"g.E().limit({MAX_EDGES}).aggregate('edges').count().as('count').select('count','edges')", - "schema.json": client.schema().getSchema(_format="groovy") + "schema.json": client.schema().getSchema(_format="groovy"), } vertexlabels = client.schema().getSchema()["vertexlabels"] - all_pk_flag = all(data.get('id_strategy') == 'PRIMARY_KEY' for data in vertexlabels) + all_pk_flag = all(data.get("id_strategy") == "PRIMARY_KEY" for data in vertexlabels) for filename, query in files.items(): write_backup_file(client, backup_subdir, filename, query, all_pk_flag) @@ -137,14 +139,22 @@ def write_backup_file(client, backup_subdir, filename, query, all_pk_flag): json.dump(data, f, ensure_ascii=False) elif filename == "vertices.json": data_full = 
client.gremlin().exec(query)["data"][0]["vertices"] - data = [{key: value for key, value in vertex.items() if key != "id"} - for vertex in data_full] if all_pk_flag else data_full + data = ( + [ + {key: value for key, value in vertex.items() if key != "id"} + for vertex in data_full + ] + if all_pk_flag + else data_full + ) json.dump(data, f, ensure_ascii=False) elif filename == "schema.json": data_full = query if isinstance(data_full, dict) and "schema" in data_full: groovy_filename = filename.replace(".json", ".groovy") - with open(os.path.join(backup_subdir, groovy_filename), "w", encoding="utf-8") as groovy_file: + with open( + os.path.join(backup_subdir, groovy_filename), "w", encoding="utf-8" + ) as groovy_file: groovy_file.write(str(data_full["schema"])) else: data = data_full @@ -171,7 +181,7 @@ def manage_backup_retention(): raise Exception("Failed to manage backup retention") from e -#TODO: In the path demo/rag_demo/configs_block.py, +# TODO: In the path demo/rag_demo/configs_block.py, # there is a function test_api_connection that is similar to this function, # but it is not straightforward to reuse def check_graph_db_connection(url: str, name: str, user: str, pwd: str, graph_space: str) -> bool: diff --git a/hugegraph-llm/src/hugegraph_llm/utils/log.py b/hugegraph-llm/src/hugegraph_llm/utils/log.py index 7076869fd..b64017454 100755 --- a/hugegraph-llm/src/hugegraph_llm/utils/log.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/log.py @@ -31,7 +31,7 @@ log_level=INFO, logger_name="root", propagate_logs=True, - stdout_logging=True + stdout_logging=True, ) # Initialize custom logger diff --git a/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py index 138b0d359..301a6bdab 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py @@ -50,9 +50,7 @@ def read_documents(input_file, input_text): texts.append(text) 
elif full_path.endswith(".pdf"): # TODO: support PDF file - raise gr.Error( - "PDF will be supported later! Try to upload text/docx now" - ) + raise gr.Error("PDF will be supported later! Try to upload text/docx now") else: raise gr.Error("Please input txt or docx file.") else: @@ -62,9 +60,7 @@ def read_documents(input_file, input_text): # pylint: disable=C0301 def get_vector_index_info(): - folder_name = get_index_folder_name( - huge_settings.graph_name, huge_settings.graph_space - ) + folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) filename_prefix = get_filename_prefix( llm_settings.embedding_type, model_map.get(llm_settings.embedding_type) ) @@ -91,15 +87,11 @@ def get_vector_index_info(): def clean_vector_index(): - folder_name = get_index_folder_name( - huge_settings.graph_name, huge_settings.graph_space - ) + folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) filename_prefix = get_filename_prefix( llm_settings.embedding_type, model_map.get(llm_settings.embedding_type) ) - VectorIndex.clean( - str(os.path.join(resource_path, folder_name, "chunks")), filename_prefix - ) + VectorIndex.clean(str(os.path.join(resource_path, folder_name, "chunks")), filename_prefix) gr.Info("Clean vector index successfully!") diff --git a/hugegraph-llm/src/tests/config/test_config.py b/hugegraph-llm/src/tests/config/test_config.py index 6c803135f..7f480befa 100644 --- a/hugegraph-llm/src/tests/config/test_config.py +++ b/hugegraph-llm/src/tests/config/test_config.py @@ -23,5 +23,6 @@ class TestConfig(unittest.TestCase): def test_config(self): import nltk from hugegraph_llm.config import resource_path + nltk.data.path.append(resource_path) nltk.data.find("corpora/stopwords") diff --git a/hugegraph-llm/src/tests/models/embeddings/test_openai_embedding.py b/hugegraph-llm/src/tests/models/embeddings/test_openai_embedding.py index b9ded0f6c..f7afd15c6 100644 --- 
a/hugegraph-llm/src/tests/models/embeddings/test_openai_embedding.py +++ b/hugegraph-llm/src/tests/models/embeddings/test_openai_embedding.py @@ -22,6 +22,7 @@ class TestOpenAIEmbedding(unittest.TestCase): def test_embedding_dimension(self): from hugegraph_llm.models.embeddings.openai import OpenAIEmbedding + embedding = OpenAIEmbedding(api_key="") result = embedding.get_text_embedding("hello world!") print(result) diff --git a/hugegraph-llm/src/tests/models/llms/test_ollama_client.py b/hugegraph-llm/src/tests/models/llms/test_ollama_client.py index caabe2a8e..7ad914468 100644 --- a/hugegraph-llm/src/tests/models/llms/test_ollama_client.py +++ b/hugegraph-llm/src/tests/models/llms/test_ollama_client.py @@ -28,7 +28,10 @@ def test_generate(self): def test_stream_generate(self): ollama_client = OllamaClient(model="llama3:8b-instruct-fp16") + def on_token_callback(chunk): print(chunk, end="", flush=True) - ollama_client.generate_streaming(prompt="What is the capital of France?", - on_token_callback=on_token_callback) + + ollama_client.generate_streaming( + prompt="What is the capital of France?", on_token_callback=on_token_callback + ) diff --git a/hugegraph-llm/src/tests/operators/common_op/test_check_schema.py b/hugegraph-llm/src/tests/operators/common_op/test_check_schema.py index d20a198f2..317d02879 100644 --- a/hugegraph-llm/src/tests/operators/common_op/test_check_schema.py +++ b/hugegraph-llm/src/tests/operators/common_op/test_check_schema.py @@ -26,12 +26,7 @@ def setUp(self): def test_schema_check_with_valid_input(self): data = { - "vertexlabels": [ - { - "name": "person", - "properties": ["name", "age", "occupation"] - } - ], + "vertexlabels": [{"name": "person", "properties": ["name", "age", "occupation"]}], "edgelabels": [ { "name": "knows", @@ -41,7 +36,7 @@ def test_schema_check_with_valid_input(self): ], } check_schema = CheckSchema(data) - self.assertEqual(check_schema.run(), {'schema': data}) + self.assertEqual(check_schema.run(), {"schema": data}) 
def test_schema_check_with_invalid_input(self): data = "invalid input" diff --git a/hugegraph-llm/src/tests/operators/common_op/test_nltk_helper.py b/hugegraph-llm/src/tests/operators/common_op/test_nltk_helper.py index 5ad73ed6f..b557cfc1b 100644 --- a/hugegraph-llm/src/tests/operators/common_op/test_nltk_helper.py +++ b/hugegraph-llm/src/tests/operators/common_op/test_nltk_helper.py @@ -22,6 +22,7 @@ class TestNLTKHelper(unittest.TestCase): def test_stopwords(self): from hugegraph_llm.operators.common_op.nltk_helper import NLTKHelper + nltk_helper = NLTKHelper() stopwords = nltk_helper.stopwords() print(stopwords) diff --git a/hugegraph-python-client/src/pyhugegraph/api/auth.py b/hugegraph-python-client/src/pyhugegraph/api/auth.py index 90b3e98d0..ab7d66169 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/auth.py +++ b/hugegraph-python-client/src/pyhugegraph/api/auth.py @@ -84,9 +84,7 @@ def create_group(self, group_name, group_description=None) -> Optional[Dict]: return self._invoke_request(data=json.dumps(data)) @router.http("DELETE", "auth/groups/{group_id}") - def delete_group( - self, group_id # pylint: disable=unused-argument - ) -> Optional[Dict]: + def delete_group(self, group_id) -> Optional[Dict]: # pylint: disable=unused-argument return self._invoke_request() @router.http("PUT", "auth/groups/{group_id}") @@ -116,9 +114,7 @@ def grant_accesses(self, group_id, target_id, access_permission) -> Optional[Dic ) @router.http("DELETE", "auth/accesses/{access_id}") - def revoke_accesses( - self, access_id # pylint: disable=unused-argument - ) -> Optional[Dict]: + def revoke_accesses(self, access_id) -> Optional[Dict]: # pylint: disable=unused-argument return self._invoke_request() @router.http("PUT", "auth/accesses/{access_id}") @@ -130,9 +126,7 @@ def modify_accesses( return self._invoke_request(data=json.dumps(data)) @router.http("GET", "auth/accesses/{access_id}") - def get_accesses( - self, access_id # pylint: disable=unused-argument - ) -> 
Optional[Dict]: + def get_accesses(self, access_id) -> Optional[Dict]: # pylint: disable=unused-argument return self._invoke_request() @router.http("GET", "auth/accesses") @@ -205,9 +199,7 @@ def update_belong( return self._invoke_request(data=json.dumps(data)) @router.http("GET", "auth/belongs/{belong_id}") - def get_belong( - self, belong_id # pylint: disable=unused-argument - ) -> Optional[Dict]: + def get_belong(self, belong_id) -> Optional[Dict]: # pylint: disable=unused-argument return self._invoke_request() @router.http("GET", "auth/belongs") diff --git a/hugegraph-python-client/src/pyhugegraph/api/graph.py b/hugegraph-python-client/src/pyhugegraph/api/graph.py index 907e01a5b..4555eeda4 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/graph.py +++ b/hugegraph-python-client/src/pyhugegraph/api/graph.py @@ -141,9 +141,7 @@ def addEdges(self, input_data) -> Optional[List[EdgeData]]: def appendEdge( self, edge_id, properties # pylint: disable=unused-argument ) -> Optional[EdgeData]: - if response := self._invoke_request( - data=json.dumps({"properties": properties}) - ): + if response := self._invoke_request(data=json.dumps({"properties": properties})): return EdgeData(response) return None @@ -151,16 +149,12 @@ def appendEdge( def eliminateEdge( self, edge_id, properties # pylint: disable=unused-argument ) -> Optional[EdgeData]: - if response := self._invoke_request( - data=json.dumps({"properties": properties}) - ): + if response := self._invoke_request(data=json.dumps({"properties": properties})): return EdgeData(response) return None @router.http("GET", "graph/edges/{edge_id}") - def getEdgeById( - self, edge_id # pylint: disable=unused-argument - ) -> Optional[EdgeData]: + def getEdgeById(self, edge_id) -> Optional[EdgeData]: # pylint: disable=unused-argument if response := self._invoke_request(): return EdgeData(response) return None diff --git a/hugegraph-python-client/src/pyhugegraph/api/schema.py 
b/hugegraph-python-client/src/pyhugegraph/api/schema.py index 8b4f54cfe..7e8926678 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/schema.py +++ b/hugegraph-python-client/src/pyhugegraph/api/schema.py @@ -64,9 +64,7 @@ def indexLabel(self, name): return index_label @router.http("GET", "schema?format={_format}") - def getSchema( - self, _format: str = "json" # pylint: disable=unused-argument - ) -> Optional[Dict]: + def getSchema(self, _format: str = "json") -> Optional[Dict]: # pylint: disable=unused-argument return self._invoke_request() @router.http("GET", "schema/propertykeys/{property_name}") @@ -84,9 +82,7 @@ def getPropertyKeys(self) -> Optional[List[PropertyKeyData]]: return None @router.http("GET", "schema/vertexlabels/{name}") - def getVertexLabel( - self, name # pylint: disable=unused-argument - ) -> Optional[VertexLabelData]: + def getVertexLabel(self, name) -> Optional[VertexLabelData]: # pylint: disable=unused-argument if response := self._invoke_request(): return VertexLabelData(response) log.error("VertexLabel not found: %s", str(response)) @@ -128,9 +124,7 @@ def getRelations(self) -> Optional[List[str]]: return None @router.http("GET", "schema/indexlabels/{name}") - def getIndexLabel( - self, name # pylint: disable=unused-argument - ) -> Optional[IndexLabelData]: + def getIndexLabel(self, name) -> Optional[IndexLabelData]: # pylint: disable=unused-argument if response := self._invoke_request(): return IndexLabelData(response) log.error("IndexLabel not found: %s", str(response)) diff --git a/hugegraph-python-client/src/pyhugegraph/api/schema_manage/index_label.py b/hugegraph-python-client/src/pyhugegraph/api/schema_manage/index_label.py index 252d487bd..acef8f968 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/schema_manage/index_label.py +++ b/hugegraph-python-client/src/pyhugegraph/api/schema_manage/index_label.py @@ -83,11 +83,13 @@ def ifNotExist(self) -> "IndexLabel": @decorator_create def create(self): dic = 
self._parameter_holder.get_dic() - data = {"name": dic["name"], - "base_type": dic["base_type"], - "base_value": dic["base_value"], - "index_type": dic["index_type"], - "fields": list(dic["fields"])} + data = { + "name": dic["name"], + "base_type": dic["base_type"], + "base_value": dic["base_value"], + "index_type": dic["index_type"], + "fields": list(dic["fields"]), + } path = "schema/indexlabels" self.clean_parameter_holder() if response := self._sess.request(path, "POST", data=json.dumps(data)): diff --git a/hugegraph-python-client/src/pyhugegraph/api/services.py b/hugegraph-python-client/src/pyhugegraph/api/services.py index e086ae13e..f353673db 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/services.py +++ b/hugegraph-python-client/src/pyhugegraph/api/services.py @@ -87,9 +87,7 @@ def list_services(self, graphspace: str): # pylint: disable=unused-argument return self._invoke_request() @router.http("GET", "/graphspaces/{graphspace}/services/{service}") - def get_service( - self, graphspace: str, service: str # pylint: disable=unused-argument - ): + def get_service(self, graphspace: str, service: str): # pylint: disable=unused-argument """ Retrieve the details of a specific service. @@ -112,9 +110,7 @@ def get_service( """ return self._invoke_request() - def delete_service( - self, graphspace: str, service: str # pylint: disable=unused-argument - ): + def delete_service(self, graphspace: str, service: str): # pylint: disable=unused-argument """ Delete a specific service within a graph space. 
diff --git a/hugegraph-python-client/src/pyhugegraph/api/traverser.py b/hugegraph-python-client/src/pyhugegraph/api/traverser.py index 628c3f4bd..72dddb07a 100644 --- a/hugegraph-python-client/src/pyhugegraph/api/traverser.py +++ b/hugegraph-python-client/src/pyhugegraph/api/traverser.py @@ -26,33 +26,23 @@ class TraverserManager(HugeParamsBase): def k_out(self, source_id, max_depth): # pylint: disable=unused-argument return self._invoke_request() - @router.http( - "GET", 'traversers/kneighbor?source="{source_id}"&max_depth={max_depth}' - ) + @router.http("GET", 'traversers/kneighbor?source="{source_id}"&max_depth={max_depth}') def k_neighbor(self, source_id, max_depth): # pylint: disable=unused-argument return self._invoke_request() - @router.http( - "GET", 'traversers/sameneighbors?vertex="{vertex_id}"&other="{other_id}"' - ) + @router.http("GET", 'traversers/sameneighbors?vertex="{vertex_id}"&other="{other_id}"') def same_neighbors(self, vertex_id, other_id): # pylint: disable=unused-argument return self._invoke_request() - @router.http( - "GET", 'traversers/jaccardsimilarity?vertex="{vertex_id}"&other="{other_id}"' - ) - def jaccard_similarity( - self, vertex_id, other_id # pylint: disable=unused-argument - ): + @router.http("GET", 'traversers/jaccardsimilarity?vertex="{vertex_id}"&other="{other_id}"') + def jaccard_similarity(self, vertex_id, other_id): # pylint: disable=unused-argument return self._invoke_request() @router.http( "GET", 'traversers/shortestpath?source="{source_id}"&target="{target_id}"&max_depth={max_depth}', ) - def shortest_path( - self, source_id, target_id, max_depth # pylint: disable=unused-argument - ): + def shortest_path(self, source_id, target_id, max_depth): # pylint: disable=unused-argument return self._invoke_request() @router.http( @@ -78,9 +68,7 @@ def weighted_shortest_path( "GET", 'traversers/singlesourceshortestpath?source="{source_id}"&max_depth={max_depth}', ) - def single_source_shortest_path( - self, source_id, max_depth # 
pylint: disable=unused-argument - ): + def single_source_shortest_path(self, source_id, max_depth): # pylint: disable=unused-argument return self._invoke_request() @router.http("POST", "traversers/multinodeshortestpath") @@ -114,9 +102,17 @@ def multi_node_shortest_path( def paths(self, source_id, target_id, max_depth): # pylint: disable=unused-argument return self._invoke_request() - @router.http("POST", 'traversers/paths') + @router.http("POST", "traversers/paths") def advanced_paths( - self, sources, targets, step, max_depth, nearest=True, capacity=10000000, limit=10, with_vertex=False + self, + sources, + targets, + step, + max_depth, + nearest=True, + capacity=10000000, + limit=10, + with_vertex=False, ): return self._invoke_request( data=json.dumps( @@ -133,7 +129,6 @@ def advanced_paths( ) ) - @router.http("POST", "traversers/customizedpaths") def customized_paths( self, sources, steps, sort_by="INCR", with_vertex=True, capacity=-1, limit=-1 @@ -152,9 +147,7 @@ def customized_paths( ) @router.http("POST", "traversers/templatepaths") - def template_paths( - self, sources, targets, steps, capacity=10000, limit=10, with_vertex=True - ): + def template_paths(self, sources, targets, steps, capacity=10000, limit=10, with_vertex=True): return self._invoke_request( data=json.dumps( { @@ -172,9 +165,7 @@ def template_paths( "GET", 'traversers/crosspoints?source="{source_id}"&target="{target_id}"&max_depth={max_depth}', ) - def crosspoints( - self, source_id, target_id, max_depth # pylint: disable=unused-argument - ): + def crosspoints(self, source_id, target_id, max_depth): # pylint: disable=unused-argument return self._invoke_request() @router.http("POST", "traversers/customizedcrosspoints") diff --git a/hugegraph-python-client/src/pyhugegraph/client.py b/hugegraph-python-client/src/pyhugegraph/client.py index 3b0301321..c9f4d1027 100644 --- a/hugegraph-python-client/src/pyhugegraph/client.py +++ b/hugegraph-python-client/src/pyhugegraph/client.py @@ -53,7 +53,7 @@ 
def __init__( user: str, pwd: str, graphspace: Optional[str] = None, - timeout: Optional[tuple[float, float]] = None + timeout: Optional[tuple[float, float]] = None, ): self.cfg = HGraphConfig(url, user, pwd, graph, graphspace, timeout or (0.5, 15.0)) diff --git a/hugegraph-python-client/src/pyhugegraph/example/hugegraph_example.py b/hugegraph-python-client/src/pyhugegraph/example/hugegraph_example.py index 4bb70dba5..d5cc0eb9d 100644 --- a/hugegraph-python-client/src/pyhugegraph/example/hugegraph_example.py +++ b/hugegraph-python-client/src/pyhugegraph/example/hugegraph_example.py @@ -26,15 +26,13 @@ schema = client.schema() schema.propertyKey("name").asText().ifNotExist().create() schema.propertyKey("birthDate").asText().ifNotExist().create() - schema.vertexLabel("Person").properties( - "name", "birthDate" - ).usePrimaryKeyId().primaryKeys("name").ifNotExist().create() - schema.vertexLabel("Movie").properties("name").usePrimaryKeyId().primaryKeys( + schema.vertexLabel("Person").properties("name", "birthDate").usePrimaryKeyId().primaryKeys( "name" ).ifNotExist().create() - schema.edgeLabel("ActedIn").sourceLabel("Person").targetLabel( - "Movie" + schema.vertexLabel("Movie").properties("name").usePrimaryKeyId().primaryKeys( + "name" ).ifNotExist().create() + schema.edgeLabel("ActedIn").sourceLabel("Person").targetLabel("Movie").ifNotExist().create() print(schema.getVertexLabels()) print(schema.getEdgeLabels()) @@ -47,9 +45,7 @@ p2 = g.addVertex("Person", {"name": "Robert De Niro", "birthDate": "1943-08-17"}) m1 = g.addVertex("Movie", {"name": "The Godfather"}) m2 = g.addVertex("Movie", {"name": "The Godfather Part II"}) - m3 = g.addVertex( - "Movie", {"name": "The Godfather Coda The Death of Michael Corleone"} - ) + m3 = g.addVertex("Movie", {"name": "The Godfather Coda The Death of Michael Corleone"}) # add Edge g.addEdge("ActedIn", p1.id, m1.id, {}) diff --git a/hugegraph-python-client/src/pyhugegraph/structure/property_key_data.py 
b/hugegraph-python-client/src/pyhugegraph/structure/property_key_data.py index 6fb7c36f0..ff50d9b2f 100644 --- a/hugegraph-python-client/src/pyhugegraph/structure/property_key_data.py +++ b/hugegraph-python-client/src/pyhugegraph/structure/property_key_data.py @@ -62,5 +62,7 @@ def userdata(self): return self.__user_data def __repr__(self): - res = f"name: {self.__name}, cardinality: {self.__cardinality}, data_type: {self.__data_type}" + res = ( + f"name: {self.__name}, cardinality: {self.__cardinality}, data_type: {self.__data_type}" + ) return res diff --git a/hugegraph-python-client/src/pyhugegraph/utils/huge_config.py b/hugegraph-python-client/src/pyhugegraph/utils/huge_config.py index 3f6d78b95..429c07c6b 100644 --- a/hugegraph-python-client/src/pyhugegraph/utils/huge_config.py +++ b/hugegraph-python-client/src/pyhugegraph/utils/huge_config.py @@ -39,7 +39,7 @@ class HGraphConfig: def __post_init__(self): # Add URL prefix compatibility check - if self.url and not self.url.startswith('http'): + if self.url and not self.url.startswith("http"): self.url = f"http://{self.url}" if self.graphspace and self.graphspace.strip(): @@ -47,9 +47,7 @@ def __post_init__(self): else: try: - response = requests.get( - f"{self.url}/versions", timeout=0.5 - ) + response = requests.get(f"{self.url}/versions", timeout=0.5) core = response.json()["versions"]["core"] log.info( # pylint: disable=logging-fstring-interpolation f"Retrieved API version information from the server: {core}." @@ -71,4 +69,6 @@ def __post_init__(self): except Exception: # pylint: disable=broad-exception-caught exc_type, exc_value, tb = sys.exc_info() traceback.print_exception(exc_type, exc_value, tb) - log.warning("Failed to retrieve API version information from the server, reverting to default v1.") + log.warning( + "Failed to retrieve API version information from the server, reverting to default v1." 
+ ) diff --git a/hugegraph-python-client/src/pyhugegraph/utils/huge_router.py b/hugegraph-python-client/src/pyhugegraph/utils/huge_router.py index 0db81ed1c..f4a38a418 100644 --- a/hugegraph-python-client/src/pyhugegraph/utils/huge_router.py +++ b/hugegraph-python-client/src/pyhugegraph/utils/huge_router.py @@ -81,9 +81,7 @@ def wrapper(self: "HGraphContext", *args: Any, **kwargs: Any) -> Any: route = RouterRegistry().routers.get(func.__qualname__) if route.request_func is None: - route.request_func = functools.partial( - self.session.request, method=method - ) + route.request_func = functools.partial(self.session.request, method=method) return func(self, *args, **kwargs) @@ -134,9 +132,7 @@ def wrapper(self: "HGraphContext", *args: Any, **kwargs: Any) -> Any: formatted_path = path # Use functools.partial to create a partial function for making requests - make_request = functools.partial( - self.session.request, formatted_path, method - ) + make_request = functools.partial(self.session.request, formatted_path, method) # Store the partial function on the instance setattr(self, f"_{func.__name__}_request", make_request) diff --git a/hugegraph-python-client/src/pyhugegraph/utils/log.py b/hugegraph-python-client/src/pyhugegraph/utils/log.py index b263d32d5..c6f6bd074 100644 --- a/hugegraph-python-client/src/pyhugegraph/utils/log.py +++ b/hugegraph-python-client/src/pyhugegraph/utils/log.py @@ -138,7 +138,9 @@ def init_logger( def _cached_log_file(filename): """Cache the opened file object""" # Use 1K buffer if writing to cloud storage - with open(filename, "a", buffering=_determine_buffer_size(filename), encoding="utf-8") as file_io: + with open( + filename, "a", buffering=_determine_buffer_size(filename), encoding="utf-8" + ) as file_io: atexit.register(file_io.close) return file_io diff --git a/hugegraph-python-client/src/pyhugegraph/utils/util.py b/hugegraph-python-client/src/pyhugegraph/utils/util.py index 90f27c24a..56a135547 100644 --- 
a/hugegraph-python-client/src/pyhugegraph/utils/util.py +++ b/hugegraph-python-client/src/pyhugegraph/utils/util.py @@ -44,7 +44,9 @@ def create_exception(response_content): def check_if_authorized(response): if response.status_code == 401: - raise NotAuthorizedError(f"Please check your username and password. {str(response.content)}") + raise NotAuthorizedError( + f"Please check your username and password. {str(response.content)}" + ) return True @@ -56,8 +58,12 @@ def check_if_success(response, error=None): req = response.request req_body = req.body if req.body else "Empty body" response_body = response.text if response.text else "Empty body" - log.error("Error-Client: Request URL: %s, Request Body: %s, Response Body: %s", - req.url, req_body, response_body) + log.error( + "Error-Client: Request URL: %s, Request Body: %s, Response Body: %s", + req.url, + req_body, + response_body, + ) raise error return True @@ -103,9 +109,14 @@ def __call__(self, response: requests.Response, method: str, path: str): details = "key 'exception' not found" req_body = response.request.body if response.request.body else "Empty body" - req_body = req_body.encode('utf-8').decode('unicode_escape') - log.error("%s: %s\n[Body]: %s\n[Server Exception]: %s", - method, str(e).encode('utf-8').decode('unicode_escape'), req_body, details) + req_body = req_body.encode("utf-8").decode("unicode_escape") + log.error( + "%s: %s\n[Body]: %s\n[Server Exception]: %s", + method, + str(e).encode("utf-8").decode("unicode_escape"), + req_body, + details, + ) if response.status_code == 404: raise NotFoundError(response.content) from e diff --git a/hugegraph-python-client/src/tests/api/test_auth.py b/hugegraph-python-client/src/tests/api/test_auth.py index d2d30cf26..10e6bad7f 100644 --- a/hugegraph-python-client/src/tests/api/test_auth.py +++ b/hugegraph-python-client/src/tests/api/test_auth.py @@ -98,9 +98,7 @@ def test_group_operations(self): self.assertEqual(group["group_name"], "test_group") # Modify the 
group - group = self.auth.modify_group( - group["id"], group_description="test_description" - ) + group = self.auth.modify_group(group["id"], group_description="test_description") self.assertEqual(group["group_description"], "test_description") # Delete the group @@ -135,9 +133,7 @@ def test_target_operations(self): [{"type": "VERTEX", "label": "person", "properties": {"city": "Shanghai"}}], ) # Verify the target was modified - self.assertEqual( - target["target_resources"][0]["properties"]["city"], "Shanghai" - ) + self.assertEqual(target["target_resources"][0]["properties"]["city"], "Shanghai") # Delete the target self.auth.delete_target(target["id"]) diff --git a/hugegraph-python-client/src/tests/api/test_version.py b/hugegraph-python-client/src/tests/api/test_version.py index 1d6325dfd..44c5f376c 100644 --- a/hugegraph-python-client/src/tests/api/test_version.py +++ b/hugegraph-python-client/src/tests/api/test_version.py @@ -42,7 +42,7 @@ def tearDown(self): def test_version(self): version = self.version.version() self.assertIsInstance(version, dict) - self.assertIn("version", version['versions']) - self.assertIn("core", version['versions']) - self.assertIn("gremlin", version['versions']) - self.assertIn("api", version['versions']) + self.assertIn("version", version["versions"]) + self.assertIn("core", version["versions"]) + self.assertIn("gremlin", version["versions"]) + self.assertIn("api", version["versions"]) diff --git a/hugegraph-python-client/src/tests/client_utils.py b/hugegraph-python-client/src/tests/client_utils.py index 63b6d0770..f711072b8 100644 --- a/hugegraph-python-client/src/tests/client_utils.py +++ b/hugegraph-python-client/src/tests/client_utils.py @@ -28,7 +28,11 @@ class ClientUtils: def __init__(self): self.client = PyHugeClient( - url=self.URL, user=self.USERNAME, pwd=self.PASSWORD, graph=self.GRAPH, graphspace=self.GRAPHSPACE + url=self.URL, + user=self.USERNAME, + pwd=self.PASSWORD, + graph=self.GRAPH, + graphspace=self.GRAPHSPACE, ) 
assert self.client is not None From 7d344301a74665bdce08e42ead57f92c6979bd93 Mon Sep 17 00:00:00 2001 From: Linyu <94553312+weijinglin@users.noreply.github.com> Date: Tue, 30 Sep 2025 10:55:47 +0800 Subject: [PATCH 4/5] refactor(RAG workflow): modularize flows, add streaming, and improve node initialization (#51) --- .../hugegraph_llm/demo/rag_demo/rag_block.py | 223 ++++++++++-------- .../src/hugegraph_llm/flows/common.py | 26 ++ .../flows/rag_flow_graph_only.py | 153 ++++++++++++ .../flows/rag_flow_graph_vector.py | 158 +++++++++++++ .../src/hugegraph_llm/flows/rag_flow_raw.py | 99 ++++++++ .../flows/rag_flow_vector_only.py | 123 ++++++++++ .../src/hugegraph_llm/flows/scheduler.py | 63 +++++ .../src/hugegraph_llm/nodes/base_node.py | 3 + .../nodes/common_node/merge_rerank_node.py | 83 +++++++ .../nodes/document_node/chunk_split.py | 2 +- .../hugegraph_node/commit_to_hugegraph.py | 3 +- .../nodes/hugegraph_node/fetch_graph_data.py | 3 +- .../nodes/hugegraph_node/graph_query_node.py | 93 ++++++++ .../nodes/hugegraph_node/schema.py | 3 +- .../nodes/index_node/build_semantic_index.py | 3 +- .../nodes/index_node/build_vector_index.py | 3 +- .../index_node/gremlin_example_index_query.py | 13 +- .../index_node/semantic_id_query_node.py | 91 +++++++ .../nodes/index_node/vector_query_node.py | 74 ++++++ .../nodes/llm_node/answer_synthesize_node.py | 99 ++++++++ .../nodes/llm_node/extract_info.py | 2 +- .../nodes/llm_node/keyword_extract_node.py | 80 +++++++ .../nodes/llm_node/prompt_generate.py | 2 +- .../nodes/llm_node/schema_build.py | 2 +- .../nodes/llm_node/text2gremlin.py | 10 +- .../src/hugegraph_llm/state/ai_state.py | 106 ++++++++- .../hugegraph_llm/utils/graph_index_utils.py | 115 ++------- 27 files changed, 1411 insertions(+), 224 deletions(-) create mode 100644 hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_only.py create mode 100644 hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_vector.py create mode 100644 
hugegraph-llm/src/hugegraph_llm/flows/rag_flow_raw.py create mode 100644 hugegraph-llm/src/hugegraph_llm/flows/rag_flow_vector_only.py create mode 100644 hugegraph-llm/src/hugegraph_llm/nodes/common_node/merge_rerank_node.py create mode 100644 hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/graph_query_node.py create mode 100644 hugegraph-llm/src/hugegraph_llm/nodes/index_node/semantic_id_query_node.py create mode 100644 hugegraph-llm/src/hugegraph_llm/nodes/index_node/vector_query_node.py create mode 100644 hugegraph-llm/src/hugegraph_llm/nodes/llm_node/answer_synthesize_node.py create mode 100644 hugegraph-llm/src/hugegraph_llm/nodes/llm_node/keyword_extract_node.py diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py index 8f70c34bd..ca36867d9 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py @@ -21,12 +21,11 @@ from typing import AsyncGenerator, Literal, Optional, Tuple import gradio as gr +from hugegraph_llm.flows.scheduler import SchedulerSingleton import pandas as pd from gradio.utils import NamedString -from hugegraph_llm.config import huge_settings, llm_settings, prompt, resource_path -from hugegraph_llm.operators.graph_rag_task import RAGPipeline -from hugegraph_llm.operators.llm_op.answer_synthesize import AnswerSynthesize +from hugegraph_llm.config import resource_path, prompt, llm_settings from hugegraph_llm.utils.decorators import with_task_id from hugegraph_llm.utils.log import log @@ -72,44 +71,51 @@ def rag_answer( gr.Warning("Please select at least one generate mode.") return "", "", "", "" - rag = RAGPipeline() - if vector_search: - rag.query_vector_index() - if graph_search: - rag.extract_keywords(extract_template=keywords_extract_prompt).keywords_to_vid( - vector_dis_threshold=vector_dis_threshold, - topk_per_keyword=topk_per_keyword, - 
).import_schema(huge_settings.graph_name).query_graphdb( - num_gremlin_generate_example=gremlin_tmpl_num, - gremlin_prompt=gremlin_prompt, - max_graph_items=max_graph_items, - ) - # TODO: add more user-defined search strategies - rag.merge_dedup_rerank( - graph_ratio=graph_ratio, - rerank_method=rerank_method, - near_neighbor_first=near_neighbor_first, - topk_return_results=topk_return_results, - ) - rag.synthesize_answer( - raw_answer, vector_only_answer, graph_only_answer, graph_vector_answer, answer_prompt - ) - + scheduler = SchedulerSingleton.get_instance() try: - context = rag.run( - verbose=True, + # Select workflow by mode to avoid fetching the wrong pipeline from the pool + if graph_vector_answer or (graph_only_answer and vector_only_answer): + flow_key = "rag_graph_vector" + elif vector_only_answer: + flow_key = "rag_vector_only" + elif graph_only_answer: + flow_key = "rag_graph_only" + elif raw_answer: + flow_key = "rag_raw" + else: + raise RuntimeError("Unsupported flow type") + + res = scheduler.schedule_flow( + flow_key, query=text, vector_search=vector_search, graph_search=graph_search, + raw_answer=raw_answer, + vector_only_answer=vector_only_answer, + graph_only_answer=graph_only_answer, + graph_vector_answer=graph_vector_answer, + graph_ratio=graph_ratio, + rerank_method=rerank_method, + near_neighbor_first=near_neighbor_first, + custom_related_information=custom_related_information, + answer_prompt=answer_prompt, + keywords_extract_prompt=keywords_extract_prompt, + gremlin_tmpl_num=gremlin_tmpl_num, + gremlin_prompt=gremlin_prompt, max_graph_items=max_graph_items, + topk_return_results=topk_return_results, + vector_dis_threshold=vector_dis_threshold, + topk_per_keyword=topk_per_keyword, ) - if context.get("switch_to_bleu"): - gr.Warning("Online reranker fails, automatically switches to local bleu rerank.") + if res.get("switch_to_bleu"): + gr.Warning( + "Online reranker fails, automatically switches to local bleu rerank." 
+ ) return ( - context.get("raw_answer", ""), - context.get("vector_only_answer", ""), - context.get("graph_only_answer", ""), - context.get("graph_vector_answer", ""), + res.get("raw_answer", ""), + res.get("vector_only_answer", ""), + res.get("graph_only_answer", ""), + res.get("graph_vector_answer", ""), ) except ValueError as e: log.critical(e) @@ -187,44 +193,47 @@ async def rag_answer_streaming( yield "", "", "", "" return - rag = RAGPipeline() - if vector_search: - rag.query_vector_index() - if graph_search: - rag.extract_keywords( - extract_template=keywords_extract_prompt - ).keywords_to_vid().import_schema(huge_settings.graph_name).query_graphdb( - num_gremlin_generate_example=gremlin_tmpl_num, - gremlin_prompt=gremlin_prompt, - ) - rag.merge_dedup_rerank( - graph_ratio, - rerank_method, - near_neighbor_first, - ) - # rag.synthesize_answer(raw_answer, vector_only_answer, graph_only_answer, graph_vector_answer, answer_prompt) - try: - context = rag.run( - verbose=True, query=text, vector_search=vector_search, graph_search=graph_search - ) - if context.get("switch_to_bleu"): - gr.Warning("Online reranker fails, automatically switches to local bleu rerank.") - answer_synthesize = AnswerSynthesize( + # Select the specific streaming workflow + scheduler = SchedulerSingleton.get_instance() + if graph_vector_answer or (graph_only_answer and vector_only_answer): + flow_key = "rag_graph_vector" + elif vector_only_answer: + flow_key = "rag_vector_only" + elif graph_only_answer: + flow_key = "rag_graph_only" + elif raw_answer: + flow_key = "rag_raw" + else: + raise RuntimeError("Unsupported flow type") + + async for res in scheduler.schedule_stream_flow( + flow_key, + query=text, + vector_search=vector_search, + graph_search=graph_search, raw_answer=raw_answer, vector_only_answer=vector_only_answer, graph_only_answer=graph_only_answer, graph_vector_answer=graph_vector_answer, - prompt_template=answer_prompt, - ) - async for context in 
answer_synthesize.run_streaming(context): - if context.get("switch_to_bleu"): - gr.Warning("Online reranker fails, automatically switches to local bleu rerank.") + graph_ratio=graph_ratio, + rerank_method=rerank_method, + near_neighbor_first=near_neighbor_first, + custom_related_information=custom_related_information, + answer_prompt=answer_prompt, + keywords_extract_prompt=keywords_extract_prompt, + gremlin_tmpl_num=gremlin_tmpl_num, + gremlin_prompt=gremlin_prompt, + ): + if res.get("switch_to_bleu"): + gr.Warning( + "Online reranker fails, automatically switches to local bleu rerank." + ) yield ( - context.get("raw_answer", ""), - context.get("vector_only_answer", ""), - context.get("graph_only_answer", ""), - context.get("graph_vector_answer", ""), + res.get("raw_answer", ""), + res.get("vector_only_answer", ""), + res.get("graph_only_answer", ""), + res.get("graph_vector_answer", ""), ) except ValueError as e: log.critical(e) @@ -242,7 +251,10 @@ def create_rag_block(): with gr.Column(scale=2): # with gr.Blocks().queue(max_size=20, default_concurrency_limit=5): inp = gr.Textbox( - value=prompt.default_question, label="Question", show_copy_button=True, lines=3 + value=prompt.default_question, + label="Question", + show_copy_button=True, + lines=3, ) # TODO: Only support inline formula now. 
Should support block formula @@ -271,7 +283,10 @@ def create_rag_block(): latex_delimiters=[{"left": "$", "right": "$", "display": False}], ) answer_prompt_input = gr.Textbox( - value=prompt.answer_prompt, label="Query Prompt", show_copy_button=True, lines=7 + value=prompt.answer_prompt, + label="Query Prompt", + show_copy_button=True, + lines=7, ) keywords_extract_prompt_input = gr.Textbox( value=prompt.keywords_extract_prompt, @@ -282,7 +297,9 @@ def create_rag_block(): with gr.Column(scale=1): with gr.Row(): - raw_radio = gr.Radio(choices=[True, False], value=False, label="Basic LLM Answer") + raw_radio = gr.Radio( + choices=[True, False], value=False, label="Basic LLM Answer" + ) vector_only_radio = gr.Radio( choices=[True, False], value=False, label="Vector-only Answer" ) @@ -306,7 +323,9 @@ def toggle_slider(enable): label="Rerank method", ) example_num = gr.Number( - value=-1, label="Template Num (<0 means disable text2gql) ", precision=0 + value=-1, + label="Template Num (<0 means disable text2gql) ", + precision=0, ) graph_ratio = gr.Slider( 0, 1, 0.6, label="Graph Ratio", step=0.1, interactive=False @@ -351,7 +370,7 @@ def toggle_slider(enable): """## 2. (Batch) Back-testing ) > 1. Download the template file & fill in the questions you want to test. > 2. Upload the file & click the button to generate answers. (Preview shows the first 40 lines) - > 3. The answer options are the same as the above RAG/Q&A frame + > 3. The answer options are the same as the above RAG/Q&A frame """ ) tests_df_headers = [ @@ -365,7 +384,9 @@ def toggle_slider(enable): # FIXME: "demo" might conflict with the graph name, it should be modified. 
answers_path = os.path.join(resource_path, "demo", "questions_answers.xlsx") questions_path = os.path.join(resource_path, "demo", "questions.xlsx") - questions_template_path = os.path.join(resource_path, "demo", "questions_template.xlsx") + questions_template_path = os.path.join( + resource_path, "demo", "questions_template.xlsx" + ) def read_file_to_excel(file: NamedString, line_count: Optional[int] = None): df = None @@ -412,20 +433,23 @@ def several_rag_answer( total_rows = len(df) for index, row in df.iterrows(): question = row.iloc[0] - basic_llm_answer, vector_only_answer, graph_only_answer, graph_vector_answer = ( - rag_answer( - question, - is_raw_answer, - is_vector_only_answer, - is_graph_only_answer, - is_graph_vector_answer, - graph_ratio_ui, - rerank_method_ui, - near_neighbor_first_ui, - custom_related_information_ui, - answer_prompt, - keywords_extract_prompt, - ) + ( + basic_llm_answer, + vector_only_answer, + graph_only_answer, + graph_vector_answer, + ) = rag_answer( + question, + is_raw_answer, + is_vector_only_answer, + is_graph_only_answer, + is_graph_vector_answer, + graph_ratio_ui, + rerank_method_ui, + near_neighbor_first_ui, + custom_related_information_ui, + answer_prompt, + keywords_extract_prompt, ) df.at[index, "Basic LLM Answer"] = basic_llm_answer df.at[index, "Vector-only Answer"] = vector_only_answer @@ -442,12 +466,18 @@ def several_rag_answer( file_types=[".xlsx", ".csv"], label="Questions File (.xlsx & csv)" ) with gr.Column(): - test_template_file = os.path.join(resource_path, "demo", "questions_template.xlsx") + test_template_file = os.path.join( + resource_path, "demo", "questions_template.xlsx" + ) gr.File(value=test_template_file, label="Download Template File") - answer_max_line_count = gr.Number(1, label="Max Lines To Show", minimum=1, maximum=40) + answer_max_line_count = gr.Number( + 1, label="Max Lines To Show", minimum=1, maximum=40 + ) answers_btn = gr.Button("Generate Answer (Batch)", variant="primary") # TODO: Set 
individual progress bars for dataframe - qa_dataframe = gr.DataFrame(label="Questions & Answers (Preview)", headers=tests_df_headers) + qa_dataframe = gr.DataFrame( + label="Questions & Answers (Preview)", headers=tests_df_headers + ) answers_btn.click( several_rag_answer, inputs=[ @@ -465,6 +495,15 @@ def several_rag_answer( ], outputs=[qa_dataframe, gr.File(label="Download Answered File", min_width=40)], ) - questions_file.change(read_file_to_excel, questions_file, [qa_dataframe, answer_max_line_count]) - answer_max_line_count.change(change_showing_excel, answer_max_line_count, qa_dataframe) - return inp, answer_prompt_input, keywords_extract_prompt_input, custom_related_information + questions_file.change( + read_file_to_excel, questions_file, [qa_dataframe, answer_max_line_count] + ) + answer_max_line_count.change( + change_showing_excel, answer_max_line_count, qa_dataframe + ) + return ( + inp, + answer_prompt_input, + keywords_extract_prompt_input, + custom_related_information, + ) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/common.py b/hugegraph-llm/src/hugegraph_llm/flows/common.py index 4c552626a..e2348466c 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/common.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/common.py @@ -14,8 +14,10 @@ # limitations under the License. from abc import ABC, abstractmethod +from typing import Dict, Any, AsyncGenerator from hugegraph_llm.state.ai_state import WkFlowInput +from hugegraph_llm.utils.log import log class BaseFlow(ABC): @@ -43,3 +45,27 @@ def post_deal(self, *args, **kwargs): Post-processing interface. """ pass + + async def post_deal_stream( + self, pipeline=None + ) -> AsyncGenerator[Dict[str, Any], None]: + """ + Streaming post-processing interface. + Subclasses can override this method as needed. 
+ """ + flow_name = self.__class__.__name__ + if pipeline is None: + yield {"error": "No pipeline provided"} + return + try: + state_json = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + log.info(f"{flow_name} post processing success") + stream_flow = state_json.get("stream_generator") + if stream_flow is None: + yield {"error": "No stream_generator found in workflow state"} + return + async for chunk in stream_flow: + yield chunk + except Exception as e: + log.error(f"{flow_name} post processing failed: {e}") + yield {"error": f"Post processing failed: {str(e)}"} diff --git a/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_only.py b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_only.py new file mode 100644 index 000000000..5feb3d471 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_only.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class RAGGraphOnlyFlow(BaseFlow):
    """Workflow that answers a query via graph retrieval only (graph_only_answer)."""

    def prepare(
        self,
        prepared_input: WkFlowInput,
        query: str,
        vector_search: bool = None,
        graph_search: bool = None,
        raw_answer: bool = None,
        vector_only_answer: bool = None,
        graph_only_answer: bool = None,
        graph_vector_answer: bool = None,
        graph_ratio: float = 0.5,
        rerank_method: Literal["bleu", "reranker"] = "bleu",
        near_neighbor_first: bool = False,
        custom_related_information: str = "",
        answer_prompt: Optional[str] = None,
        keywords_extract_prompt: Optional[str] = None,
        gremlin_tmpl_num: Optional[int] = -1,
        gremlin_prompt: Optional[str] = None,
        max_graph_items: int = None,
        topk_return_results: int = None,
        vector_dis_threshold: float = None,
        topk_per_keyword: int = None,
        **_: dict,
    ):
        """Populate *prepared_input* with every knob the graph-only pipeline reads.

        Unset optional arguments fall back to the global prompt/settings defaults.
        """
        # Resolve the settings-backed default once; it is used twice below.
        max_items = max_graph_items or huge_settings.max_graph_items

        prepared_input.query = query
        prepared_input.vector_search = vector_search
        prepared_input.graph_search = graph_search
        prepared_input.raw_answer = raw_answer
        prepared_input.vector_only_answer = vector_only_answer
        prepared_input.graph_only_answer = graph_only_answer
        prepared_input.graph_vector_answer = graph_vector_answer
        prepared_input.gremlin_tmpl_num = gremlin_tmpl_num
        prepared_input.gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt
        prepared_input.max_graph_items = max_items
        prepared_input.topk_per_keyword = (
            topk_per_keyword or huge_settings.topk_per_keyword
        )
        prepared_input.topk_return_results = (
            topk_return_results or huge_settings.topk_return_results
        )
        prepared_input.rerank_method = rerank_method
        prepared_input.near_neighbor_first = near_neighbor_first
        prepared_input.keywords_extract_prompt = (
            keywords_extract_prompt or prompt.keywords_extract_prompt
        )
        prepared_input.answer_prompt = answer_prompt or prompt.answer_prompt
        prepared_input.custom_related_information = custom_related_information
        prepared_input.vector_dis_threshold = (
            vector_dis_threshold or huge_settings.vector_dis_threshold
        )
        prepared_input.schema = huge_settings.graph_name

        prepared_input.data_json = {
            "query": query,
            "vector_search": vector_search,
            "graph_search": graph_search,
            "max_graph_items": max_items,
        }

    def build_flow(self, **kwargs):
        """Assemble the graph-only pipeline.

        Topology: keyword-extract -> semantic-id lookup; schema + semantic
        feed the graph query, whose result is reranked and synthesized.
        """
        pipeline = GPipeline()
        wk_input = WkFlowInput()
        self.prepare(wk_input, **kwargs)
        pipeline.createGParam(wk_input, "wkflow_input")
        pipeline.createGParam(WkFlowState(), "wkflow_state")

        kw_node = KeywordExtractNode()
        semantic_node = SemanticIdQueryNode()
        schema_node = SchemaNode()
        query_node = GraphQueryNode()
        rerank_node = MergeRerankNode()
        answer_node = AnswerSynthesizeNode()

        # Register nodes with their upstream dependency sets (names unchanged).
        pipeline.registerGElement(kw_node, set(), "only_keyword")
        pipeline.registerGElement(semantic_node, {kw_node}, "only_semantic")
        pipeline.registerGElement(schema_node, set(), "only_schema")
        pipeline.registerGElement(
            query_node, {schema_node, semantic_node}, "only_graph"
        )
        pipeline.registerGElement(rerank_node, {query_node}, "merge_one")
        pipeline.registerGElement(answer_node, {rerank_node}, "graph")
        log.info("RAGGraphOnlyFlow pipeline built successfully")
        return pipeline

    def post_deal(self, pipeline=None):
        """Collect the four answer fields from the workflow state.

        Returns a dict on success; on failure a JSON *string* describing the
        error (kept as-is for compatibility with existing callers).
        """
        if pipeline is None:
            return json.dumps(
                {"error": "No pipeline provided"}, ensure_ascii=False, indent=2
            )
        try:
            state = pipeline.getGParamWithNoEmpty("wkflow_state").to_json()
            log.info("RAGGraphOnlyFlow post processing success")
            answer_keys = (
                "raw_answer",
                "vector_only_answer",
                "graph_only_answer",
                "graph_vector_answer",
            )
            return {key: state.get(key, "") for key in answer_keys}
        except Exception as e:
            log.error(f"RAGGraphOnlyFlow post processing failed: {e}")
            return json.dumps(
                {"error": f"Post processing failed: {str(e)}"},
                ensure_ascii=False,
                indent=2,
            )
class RAGGraphVectorFlow(BaseFlow):
    """Workflow for graph + vector hybrid answering (graph_vector_answer)."""

    def prepare(
        self,
        prepared_input: WkFlowInput,
        query: str,
        vector_search: bool = None,
        graph_search: bool = None,
        raw_answer: bool = None,
        vector_only_answer: bool = None,
        graph_only_answer: bool = None,
        graph_vector_answer: bool = None,
        graph_ratio: float = 0.5,
        rerank_method: Literal["bleu", "reranker"] = "bleu",
        near_neighbor_first: bool = False,
        custom_related_information: str = "",
        answer_prompt: Optional[str] = None,
        keywords_extract_prompt: Optional[str] = None,
        gremlin_tmpl_num: Optional[int] = -1,
        gremlin_prompt: Optional[str] = None,
        max_graph_items: int = None,
        topk_return_results: int = None,
        vector_dis_threshold: float = None,
        topk_per_keyword: int = None,
        **_: dict,
    ):
        """Fill *prepared_input* for the hybrid pipeline.

        Unset optional arguments fall back to global prompt/settings defaults.
        """
        # Settings-backed default used both on the input object and in data_json.
        max_items = max_graph_items or huge_settings.max_graph_items

        prepared_input.query = query
        prepared_input.vector_search = vector_search
        prepared_input.graph_search = graph_search
        prepared_input.raw_answer = raw_answer
        prepared_input.vector_only_answer = vector_only_answer
        prepared_input.graph_only_answer = graph_only_answer
        prepared_input.graph_vector_answer = graph_vector_answer
        prepared_input.graph_ratio = graph_ratio
        prepared_input.gremlin_tmpl_num = gremlin_tmpl_num
        prepared_input.gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt
        prepared_input.max_graph_items = max_items
        prepared_input.topk_return_results = (
            topk_return_results or huge_settings.topk_return_results
        )
        prepared_input.topk_per_keyword = (
            topk_per_keyword or huge_settings.topk_per_keyword
        )
        prepared_input.vector_dis_threshold = (
            vector_dis_threshold or huge_settings.vector_dis_threshold
        )
        prepared_input.rerank_method = rerank_method
        prepared_input.near_neighbor_first = near_neighbor_first
        prepared_input.keywords_extract_prompt = (
            keywords_extract_prompt or prompt.keywords_extract_prompt
        )
        prepared_input.answer_prompt = answer_prompt or prompt.answer_prompt
        prepared_input.custom_related_information = custom_related_information
        prepared_input.schema = huge_settings.graph_name

        prepared_input.data_json = {
            "query": query,
            "vector_search": vector_search,
            "graph_search": graph_search,
            "max_graph_items": max_items,
        }

    def build_flow(self, **kwargs):
        """Assemble the hybrid pipeline.

        Two retrieval branches (vector query; keyword -> semantic -> graph
        query with schema) converge in the merge/rerank node, then the
        answer is synthesized.
        """
        pipeline = GPipeline()
        wk_input = WkFlowInput()
        self.prepare(wk_input, **kwargs)
        pipeline.createGParam(wk_input, "wkflow_input")
        pipeline.createGParam(WkFlowState(), "wkflow_state")

        vector_node = VectorQueryNode()
        kw_node = KeywordExtractNode()
        semantic_node = SemanticIdQueryNode()
        schema_node = SchemaNode()
        query_node = GraphQueryNode()
        rerank_node = MergeRerankNode()
        answer_node = AnswerSynthesizeNode()

        # Register nodes with their upstream dependency sets (names unchanged).
        pipeline.registerGElement(vector_node, set(), "vector")
        pipeline.registerGElement(kw_node, set(), "keyword")
        pipeline.registerGElement(semantic_node, {kw_node}, "semantic")
        pipeline.registerGElement(schema_node, set(), "schema")
        pipeline.registerGElement(
            query_node, {schema_node, semantic_node}, "graph"
        )
        pipeline.registerGElement(
            rerank_node, {query_node, vector_node}, "merge"
        )
        pipeline.registerGElement(answer_node, {rerank_node}, "graph_vector")
        log.info("RAGGraphVectorFlow pipeline built successfully")
        return pipeline

    def post_deal(self, pipeline=None):
        """Collect the four answer fields from the workflow state.

        Returns a dict on success; on failure a JSON *string* describing the
        error (kept as-is for compatibility with existing callers).
        """
        if pipeline is None:
            return json.dumps(
                {"error": "No pipeline provided"}, ensure_ascii=False, indent=2
            )
        try:
            state = pipeline.getGParamWithNoEmpty("wkflow_state").to_json()
            log.info("RAGGraphVectorFlow post processing success")
            answer_keys = (
                "raw_answer",
                "vector_only_answer",
                "graph_only_answer",
                "graph_vector_answer",
            )
            return {key: state.get(key, "") for key in answer_keys}
        except Exception as e:
            log.error(f"RAGGraphVectorFlow post processing failed: {e}")
            return json.dumps(
                {"error": f"Post processing failed: {str(e)}"},
                ensure_ascii=False,
                indent=2,
            )
class RAGRawFlow(BaseFlow):
    """Workflow that produces a plain LLM answer only (raw_answer)."""

    def prepare(
        self,
        prepared_input: WkFlowInput,
        query: str,
        vector_search: bool = None,
        graph_search: bool = None,
        raw_answer: bool = None,
        vector_only_answer: bool = None,
        graph_only_answer: bool = None,
        graph_vector_answer: bool = None,
        custom_related_information: str = "",
        answer_prompt: Optional[str] = None,
        max_graph_items: int = None,
        **_: dict,
    ):
        """Fill *prepared_input* for the raw-answer pipeline.

        Note: vector_search / graph_search only travel via ``data_json`` here;
        they are not stored on the input object itself.
        """
        prepared_input.query = query
        prepared_input.raw_answer = raw_answer
        prepared_input.vector_only_answer = vector_only_answer
        prepared_input.graph_only_answer = graph_only_answer
        prepared_input.graph_vector_answer = graph_vector_answer
        prepared_input.custom_related_information = custom_related_information
        prepared_input.answer_prompt = answer_prompt or prompt.answer_prompt
        prepared_input.schema = huge_settings.graph_name

        prepared_input.data_json = {
            "query": query,
            "vector_search": vector_search,
            "graph_search": graph_search,
            "max_graph_items": max_graph_items or huge_settings.max_graph_items,
        }

    def build_flow(self, **kwargs):
        """Assemble the single-node pipeline: answer synthesis only."""
        pipeline = GPipeline()
        wk_input = WkFlowInput()
        self.prepare(wk_input, **kwargs)
        pipeline.createGParam(wk_input, "wkflow_input")
        pipeline.createGParam(WkFlowState(), "wkflow_state")

        # One node, no dependencies: the LLM answers directly.
        answer_node = AnswerSynthesizeNode()
        pipeline.registerGElement(answer_node, set(), "raw")
        log.info("RAGRawFlow pipeline built successfully")
        return pipeline

    def post_deal(self, pipeline=None):
        """Collect the four answer fields from the workflow state.

        Returns a dict on success; on failure a JSON *string* describing the
        error (kept as-is for compatibility with existing callers).
        """
        if pipeline is None:
            return json.dumps(
                {"error": "No pipeline provided"}, ensure_ascii=False, indent=2
            )
        try:
            state = pipeline.getGParamWithNoEmpty("wkflow_state").to_json()
            log.info("RAGRawFlow post processing success")
            answer_keys = (
                "raw_answer",
                "vector_only_answer",
                "graph_only_answer",
                "graph_vector_answer",
            )
            return {key: state.get(key, "") for key in answer_keys}
        except Exception as e:
            log.error(f"RAGRawFlow post processing failed: {e}")
            return json.dumps(
                {"error": f"Post processing failed: {str(e)}"},
                ensure_ascii=False,
                indent=2,
            )
class RAGVectorOnlyFlow(BaseFlow):
    """Workflow that answers a query via vector retrieval only (vector_only_answer)."""

    def prepare(
        self,
        prepared_input: WkFlowInput,
        query: str,
        vector_search: bool = None,
        graph_search: bool = None,
        raw_answer: bool = None,
        vector_only_answer: bool = None,
        graph_only_answer: bool = None,
        graph_vector_answer: bool = None,
        rerank_method: Literal["bleu", "reranker"] = "bleu",
        near_neighbor_first: bool = False,
        custom_related_information: str = "",
        answer_prompt: Optional[str] = None,
        max_graph_items: int = None,
        topk_return_results: int = None,
        vector_dis_threshold: float = None,
        **_: dict,
    ):
        """Fill *prepared_input* for the vector-only pipeline.

        Unset optional arguments fall back to global prompt/settings defaults.
        """
        prepared_input.query = query
        prepared_input.vector_search = vector_search
        prepared_input.graph_search = graph_search
        prepared_input.raw_answer = raw_answer
        prepared_input.vector_only_answer = vector_only_answer
        prepared_input.graph_only_answer = graph_only_answer
        prepared_input.graph_vector_answer = graph_vector_answer
        prepared_input.vector_dis_threshold = (
            vector_dis_threshold or huge_settings.vector_dis_threshold
        )
        prepared_input.topk_return_results = (
            topk_return_results or huge_settings.topk_return_results
        )
        prepared_input.rerank_method = rerank_method
        prepared_input.near_neighbor_first = near_neighbor_first
        prepared_input.custom_related_information = custom_related_information
        prepared_input.answer_prompt = answer_prompt or prompt.answer_prompt
        prepared_input.schema = huge_settings.graph_name

        prepared_input.data_json = {
            "query": query,
            "vector_search": vector_search,
            "graph_search": graph_search,
            "max_graph_items": max_graph_items or huge_settings.max_graph_items,
        }

    def build_flow(self, **kwargs):
        """Assemble the vector-only pipeline: vector query -> rerank -> answer."""
        pipeline = GPipeline()
        wk_input = WkFlowInput()
        self.prepare(wk_input, **kwargs)
        pipeline.createGParam(wk_input, "wkflow_input")
        pipeline.createGParam(WkFlowState(), "wkflow_state")

        vector_node = VectorQueryNode()
        rerank_node = MergeRerankNode()
        answer_node = AnswerSynthesizeNode()

        # Register nodes with their upstream dependency sets (names unchanged).
        pipeline.registerGElement(vector_node, set(), "only_vector")
        pipeline.registerGElement(rerank_node, {vector_node}, "merge_two")
        pipeline.registerGElement(answer_node, {rerank_node}, "vector")
        log.info("RAGVectorOnlyFlow pipeline built successfully")
        return pipeline

    def post_deal(self, pipeline=None):
        """Collect the four answer fields from the workflow state.

        Returns a dict on success; on failure a JSON *string* describing the
        error (kept as-is for compatibility with existing callers).
        """
        if pipeline is None:
            return json.dumps(
                {"error": "No pipeline provided"}, ensure_ascii=False, indent=2
            )
        try:
            state = pipeline.getGParamWithNoEmpty("wkflow_state").to_json()
            log.info("RAGVectorOnlyFlow post processing success")
            answer_keys = (
                "raw_answer",
                "vector_only_answer",
                "graph_only_answer",
                "graph_vector_answer",
            )
            return {key: state.get(key, "") for key in answer_keys}
        except Exception as e:
            log.error(f"RAGVectorOnlyFlow post processing failed: {e}")
            return json.dumps(
                {"error": f"Post processing failed: {str(e)}"},
                ensure_ascii=False,
                indent=2,
            )
+130,47 @@ def schedule_flow(self, flow: str, *args, **kwargs): manager.release(pipeline) return res + async def schedule_stream_flow(self, flow: str, *args, **kwargs): + if flow not in self.pipeline_pool: + raise ValueError(f"Unsupported workflow {flow}") + manager: GPipelineManager = self.pipeline_pool[flow]["manager"] + flow: BaseFlow = self.pipeline_pool[flow]["flow"] + pipeline: GPipeline = manager.fetch() + if pipeline is None: + # call coresponding flow_func to create new workflow + pipeline = flow.build_flow(*args, **kwargs) + try: + pipeline.getGParamWithNoEmpty("wkflow_input").stream = True + status = pipeline.init() + if status.isErr(): + error_msg = f"Error in flow init: {status.getInfo()}" + log.error(error_msg) + raise RuntimeError(error_msg) + status = pipeline.run() + if status.isErr(): + error_msg = f"Error in flow execution: {status.getInfo()}" + log.error(error_msg) + raise RuntimeError(error_msg) + async for res in flow.post_deal_stream(pipeline): + yield res + finally: + manager.add(pipeline) + else: + try: + # fetch pipeline & prepare input for flow + prepared_input: WkFlowInput = pipeline.getGParamWithNoEmpty( + "wkflow_input" + ) + prepared_input.stream = True + flow.prepare(prepared_input, *args, **kwargs) + status = pipeline.run() + if status.isErr(): + raise RuntimeError(f"Error in flow execution {status.getInfo()}") + async for res in flow.post_deal_stream(pipeline): + yield res + finally: + manager.release(pipeline) + class SchedulerSingleton: _instance = None diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py index 0ea0675c0..f90167305 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py @@ -30,6 +30,9 @@ def node_init(self): Node initialization method, can be overridden by subclasses. Returns a CStatus object indicating whether initialization succeeded. 
class MergeRerankNode(BaseNode):
    """
    Merge and rerank node: merges vector and graph query results,
    deduplicates them, and reranks before answer synthesis.
    """

    # Built in node_init from the user configuration carried by wk_input.
    operator: MergeDedupRerank

    def node_init(self):
        """
        Initialize the merge-and-rerank operator from wk_input settings.

        Returns a CStatus indicating whether initialization succeeded.
        """
        try:
            # Read user configuration parameters from wk_input
            embedding = get_embedding(llm_settings)
            # BUGFIX: use an explicit None check — `graph_ratio or 0.5` would
            # silently replace the valid value 0.0 (the UI slider allows 0)
            # with the 0.5 default.
            graph_ratio = (
                0.5
                if self.wk_input.graph_ratio is None
                else self.wk_input.graph_ratio
            )
            rerank_method = self.wk_input.rerank_method or "bleu"
            near_neighbor_first = self.wk_input.near_neighbor_first or False
            custom_related_information = self.wk_input.custom_related_information or ""
            topk_return_results = (
                self.wk_input.topk_return_results or huge_settings.topk_return_results
            )

            self.operator = MergeDedupRerank(
                embedding=embedding,
                graph_ratio=graph_ratio,
                method=rerank_method,
                near_neighbor_first=near_neighbor_first,
                custom_related_information=custom_related_information,
                topk_return_results=topk_return_results,
            )
            return super().node_init()
        except Exception as e:
            log.error(f"Failed to initialize MergeRerankNode: {e}")
            # Imported lazily so the module loads even where PyCGraph is absent.
            from PyCGraph import CStatus

            return CStatus(-1, f"MergeRerankNode initialization failed: {e}")

    def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute merge/dedup/rerank on *data_json*.

        On failure the input is passed through unchanged so downstream
        nodes still receive a usable context.
        """
        try:
            # Perform merge, deduplication, and rerank
            result = self.operator.run(data_json)

            # Log result statistics
            vector_count = len(result.get("vector_result", []))
            graph_count = len(result.get("graph_result", []))
            merged_count = len(result.get("merged_result", []))

            log.info(
                f"Merge and rerank completed: {vector_count} vector results, "
                f"{graph_count} graph results, {merged_count} merged results"
            )

            return result

        except Exception as e:
            log.error(f"Merge and rerank failed: {e}")
            return data_json
-from PyCGraph import CStatus from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.hugegraph_op.commit_to_hugegraph import Commit2Graph from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState @@ -29,7 +28,7 @@ def node_init(self): if data_json: self.context.assign_from_json(data_json) self.commit_to_graph_op = Commit2Graph() - return CStatus() + return super().node_init() def operator_schedule(self, data_json): return self.commit_to_graph_op.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/fetch_graph_data.py b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/fetch_graph_data.py index b2434e524..99b428e5e 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/fetch_graph_data.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/fetch_graph_data.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from PyCGraph import CStatus from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.hugegraph_op.fetch_graph_data import FetchGraphData from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState @@ -27,7 +26,7 @@ class FetchGraphDataNode(BaseNode): def node_init(self): self.fetch_graph_data_op = FetchGraphData(get_hg_client()) - return CStatus() + return super().node_init() def operator_schedule(self, data_json): return self.fetch_graph_data_op.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/graph_query_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/graph_query_node.py new file mode 100644 index 000000000..ae65ccb33 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/graph_query_node.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from PyCGraph import CStatus +from typing import Dict, Any +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.operators.hugegraph_op.graph_rag_query import GraphRAGQuery +from hugegraph_llm.config import huge_settings, prompt +from hugegraph_llm.utils.log import log + + +class GraphQueryNode(BaseNode): + """ + Graph query node, responsible for retrieving relevant information from the graph database. + """ + + graph_rag_query: GraphRAGQuery + + def node_init(self): + """ + Initialize the graph query operator. 
+ """ + try: + graph_name = huge_settings.graph_name + if not graph_name: + return CStatus(-1, "graph_name is required in wk_input") + + max_deep = self.wk_input.max_deep or 2 + max_graph_items = ( + self.wk_input.max_graph_items or huge_settings.max_graph_items + ) + max_v_prop_len = self.wk_input.max_v_prop_len or 2048 + max_e_prop_len = self.wk_input.max_e_prop_len or 256 + prop_to_match = self.wk_input.prop_to_match + num_gremlin_generate_example = self.wk_input.gremlin_tmpl_num or -1 + gremlin_prompt = ( + self.wk_input.gremlin_prompt or prompt.gremlin_generate_prompt + ) + + # Initialize GraphRAGQuery operator + self.graph_rag_query = GraphRAGQuery( + max_deep=max_deep, + max_graph_items=max_graph_items, + max_v_prop_len=max_v_prop_len, + max_e_prop_len=max_e_prop_len, + prop_to_match=prop_to_match, + num_gremlin_generate_example=num_gremlin_generate_example, + gremlin_prompt=gremlin_prompt, + ) + + return super().node_init() + except Exception as e: + log.error(f"Failed to initialize GraphQueryNode: {e}") + + return CStatus(-1, f"GraphQueryNode initialization failed: {e}") + + def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]: + """ + Execute the graph query operation. 
+ """ + try: + # Get the query text from input + query = data_json.get("query", "") + + if not query: + log.warning("No query text provided for graph query") + return data_json + + # Execute the graph query (assuming schema and semantic query have been completed in previous nodes) + graph_result = self.graph_rag_query.run(data_json) + data_json.update(graph_result) + + log.info( + f"Graph query completed, found {len(data_json.get('graph_result', []))} results" + ) + + return data_json + + except Exception as e: + log.error(f"Graph query failed: {e}") + return data_json diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py index 84719d9eb..3face9d63 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py @@ -15,7 +15,6 @@ import json -from PyCGraph import CStatus from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.common_op.check_schema import CheckSchema from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager @@ -59,7 +58,7 @@ def node_init(self): else: log.info("Get schema '%s' from graphdb.", self.schema) self.schema_manager = self._import_schema(from_hugegraph=self.schema) - return CStatus() + return super().node_init() def operator_schedule(self, data_json): log.debug("SchemaNode input state: %s", data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_semantic_index.py b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_semantic_index.py index ab31fa394..c01cffc91 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_semantic_index.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_semantic_index.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from PyCGraph import CStatus from hugegraph_llm.config import llm_settings from hugegraph_llm.models.embeddings.init_embedding import get_embedding from hugegraph_llm.nodes.base_node import BaseNode @@ -28,7 +27,7 @@ class BuildSemanticIndexNode(BaseNode): def node_init(self): self.build_semantic_index_op = BuildSemanticIndex(get_embedding(llm_settings)) - return CStatus() + return super().node_init() def operator_schedule(self, data_json): return self.build_semantic_index_op.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_vector_index.py b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_vector_index.py index cf2f9b677..1f6a3c75b 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_vector_index.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_vector_index.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from PyCGraph import CStatus from hugegraph_llm.config import llm_settings from hugegraph_llm.models.embeddings.init_embedding import get_embedding from hugegraph_llm.nodes.base_node import BaseNode @@ -28,7 +27,7 @@ class BuildVectorIndexNode(BaseNode): def node_init(self): self.build_vector_index_op = BuildVectorIndex(get_embedding(llm_settings)) - return CStatus() + return super().node_init() def operator_schedule(self, data_json): return self.build_vector_index_op.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/gremlin_example_index_query.py b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/gremlin_example_index_query.py index eb033d869..e9283598a 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/gremlin_example_index_query.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/gremlin_example_index_query.py @@ -19,9 +19,12 @@ from PyCGraph import CStatus +from hugegraph_llm.config import llm_settings from hugegraph_llm.nodes.base_node import BaseNode -from 
hugegraph_llm.operators.index_op.gremlin_example_index_query import GremlinExampleIndexQuery -from hugegraph_llm.models.embeddings.init_embedding import Embeddings +from hugegraph_llm.operators.index_op.gremlin_example_index_query import ( + GremlinExampleIndexQuery, +) +from hugegraph_llm.models.embeddings.init_embedding import get_embedding class GremlinExampleIndexQueryNode(BaseNode): @@ -29,13 +32,15 @@ class GremlinExampleIndexQueryNode(BaseNode): def node_init(self): # Build operator (index lazy-loading handled in operator) - embedding = Embeddings().get_embedding() + embedding = get_embedding(llm_settings) example_num = getattr(self.wk_input, "example_num", None) if not isinstance(example_num, int): example_num = 2 # Clamp to [0, 10] example_num = max(0, min(10, example_num)) - self.operator = GremlinExampleIndexQuery(embedding=embedding, num_examples=example_num) + self.operator = GremlinExampleIndexQuery( + embedding=embedding, num_examples=example_num + ) return CStatus() def operator_schedule(self, data_json: Dict[str, Any]): diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/semantic_id_query_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/semantic_id_query_node.py new file mode 100644 index 000000000..bf605aa49 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/semantic_id_query_node.py @@ -0,0 +1,91 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from PyCGraph import CStatus +from typing import Dict, Any +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.operators.index_op.semantic_id_query import SemanticIdQuery +from hugegraph_llm.models.embeddings.init_embedding import get_embedding +from hugegraph_llm.config import huge_settings, llm_settings +from hugegraph_llm.utils.log import log + + +class SemanticIdQueryNode(BaseNode): + """ + Semantic ID query node, responsible for semantic matching based on keywords. + """ + + semantic_id_query: SemanticIdQuery + + def node_init(self): + """ + Initialize the semantic ID query operator. 
+ """ + try: + graph_name = huge_settings.graph_name + if not graph_name: + return CStatus(-1, "graph_name is required in wk_input") + + embedding = get_embedding(llm_settings) + by = self.wk_input.semantic_by or "keywords" + topk_per_keyword = ( + self.wk_input.topk_per_keyword or huge_settings.topk_per_keyword + ) + topk_per_query = self.wk_input.topk_per_query or 10 + vector_dis_threshold = ( + self.wk_input.vector_dis_threshold or huge_settings.vector_dis_threshold + ) + + # Initialize the semantic ID query operator + self.semantic_id_query = SemanticIdQuery( + embedding=embedding, + by=by, + topk_per_keyword=topk_per_keyword, + topk_per_query=topk_per_query, + vector_dis_threshold=vector_dis_threshold, + ) + + return super().node_init() + except Exception as e: + log.error(f"Failed to initialize SemanticIdQueryNode: {e}") + + return CStatus(-1, f"SemanticIdQueryNode initialization failed: {e}") + + def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]: + """ + Execute the semantic ID query operation. 
+ """ + try: + # Get the query text and keywords from input + query = data_json.get("query", "") + keywords = data_json.get("keywords", []) + + if not query and not keywords: + log.warning("No query text or keywords provided for semantic query") + return data_json + + # Perform the semantic query + semantic_result = self.semantic_id_query.run(data_json) + + match_vids = semantic_result.get("match_vids", []) + log.info( + f"Semantic query completed, found {len(match_vids)} matching vertex IDs" + ) + + return semantic_result + + except Exception as e: + log.error(f"Semantic query failed: {e}") + return data_json diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/vector_query_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/vector_query_node.py new file mode 100644 index 000000000..48b50acf3 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/vector_query_node.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import Dict, Any +from hugegraph_llm.config import llm_settings +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.operators.index_op.vector_index_query import VectorIndexQuery +from hugegraph_llm.models.embeddings.init_embedding import get_embedding +from hugegraph_llm.utils.log import log + + +class VectorQueryNode(BaseNode): + """ + Vector query node, responsible for retrieving relevant documents from the vector index + """ + + operator: VectorIndexQuery + + def node_init(self): + """ + Initialize the vector query operator + """ + try: + # 从 wk_input 中读取用户配置参数 + embedding = get_embedding(llm_settings) + max_items = ( + self.wk_input.max_items if self.wk_input.max_items is not None else 3 + ) + + self.operator = VectorIndexQuery(embedding=embedding, topk=max_items) + return super().node_init() + except Exception as e: + log.error(f"Failed to initialize VectorQueryNode: {e}") + from PyCGraph import CStatus + + return CStatus(-1, f"VectorQueryNode initialization failed: {e}") + + def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]: + """ + Execute the vector query operation + """ + try: + # Get the query text from input + query = data_json.get("query", "") + if not query: + log.warning("No query text provided for vector query") + return data_json + + # Perform the vector query + result = self.operator.run({"query": query}) + + # Update the state + data_json.update(result) + log.info( + f"Vector query completed, found {len(result.get('vector_result', []))} results" + ) + + return data_json + + except Exception as e: + log.error(f"Vector query failed: {e}") + return data_json diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/answer_synthesize_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/answer_synthesize_node.py new file mode 100644 index 000000000..22b970b4a --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/answer_synthesize_node.py @@ -0,0 +1,99 @@ +# Licensed to the 
Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Any +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.operators.llm_op.answer_synthesize import AnswerSynthesize +from hugegraph_llm.utils.log import log + + +class AnswerSynthesizeNode(BaseNode): + """ + Answer synthesis node, responsible for generating the final answer based on retrieval results. + """ + + operator: AnswerSynthesize + + def node_init(self): + """ + Initialize the answer synthesis operator. 
+ """ + try: + prompt_template = self.wk_input.answer_prompt + raw_answer = self.wk_input.raw_answer or False + vector_only_answer = self.wk_input.vector_only_answer or False + graph_only_answer = self.wk_input.graph_only_answer or False + graph_vector_answer = self.wk_input.graph_vector_answer or False + + self.operator = AnswerSynthesize( + prompt_template=prompt_template, + raw_answer=raw_answer, + vector_only_answer=vector_only_answer, + graph_only_answer=graph_only_answer, + graph_vector_answer=graph_vector_answer, + ) + return super().node_init() + except Exception as e: + log.error(f"Failed to initialize AnswerSynthesizeNode: {e}") + from PyCGraph import CStatus + + return CStatus(-1, f"AnswerSynthesizeNode initialization failed: {e}") + + def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]: + """ + Execute the answer synthesis operation. + """ + try: + if self.getGParamWithNoEmpty("wkflow_input").stream: + # Streaming mode: return a generator for streaming output + data_json["stream_generator"] = self.operator.run_streaming(data_json) + return data_json + else: + # Non-streaming mode: execute answer synthesis + result = self.operator.run(data_json) + + # Record the types of answers generated + answer_types = [] + if result.get("raw_answer"): + answer_types.append("raw") + if result.get("vector_only_answer"): + answer_types.append("vector_only") + if result.get("graph_only_answer"): + answer_types.append("graph_only") + if result.get("graph_vector_answer"): + answer_types.append("graph_vector") + + log.info( + f"Answer synthesis completed for types: {', '.join(answer_types)}" + ) + + # Print enabled answer types according to self.wk_input configuration + wk_input_types = [] + if getattr(self.wk_input, "raw_answer", False): + wk_input_types.append("raw") + if getattr(self.wk_input, "vector_only_answer", False): + wk_input_types.append("vector_only") + if getattr(self.wk_input, "graph_only_answer", False): + 
wk_input_types.append("graph_only") + if getattr(self.wk_input, "graph_vector_answer", False): + wk_input_types.append("graph_vector") + log.info( + f"Enabled answer types according to wk_input config: {', '.join(wk_input_types)}" + ) + return result + + except Exception as e: + log.error(f"Answer synthesis failed: {e}") + return data_json diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py index 8bceed804..628765f58 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py @@ -43,7 +43,7 @@ def node_init(self): self.property_graph_extract = PropertyGraphExtract(llm, example_prompt) else: return CStatus(-1, f"Unsupported extract_type: {extract_type}") - return CStatus() + return super().node_init() def operator_schedule(self, data_json): if self.extract_type == "triples": diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/keyword_extract_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/keyword_extract_node.py new file mode 100644 index 000000000..76fc06eb3 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/keyword_extract_node.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Any +from PyCGraph import CStatus + +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.operators.llm_op.keyword_extract import KeywordExtract +from hugegraph_llm.utils.log import log + + +class KeywordExtractNode(BaseNode): + operator: KeywordExtract + + """ + Keyword extraction node, responsible for extracting keywords from query text. + """ + + def node_init(self): + """ + Initialize the keyword extraction operator. + """ + try: + max_keywords = ( + self.wk_input.max_keywords + if self.wk_input.max_keywords is not None + else 5 + ) + language = ( + self.wk_input.language + if self.wk_input.language is not None + else "english" + ) + extract_template = self.wk_input.keywords_extract_prompt + + self.operator = KeywordExtract( + text=self.wk_input.query, + max_keywords=max_keywords, + language=language, + extract_template=extract_template, + ) + return super().node_init() + except Exception as e: + log.error(f"Failed to initialize KeywordExtractNode: {e}") + return CStatus(-1, f"KeywordExtractNode initialization failed: {e}") + + def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]: + """ + Execute the keyword extraction operation. 
+ """ + try: + # Perform keyword extraction + result = self.operator.run(data_json) + if "keywords" not in result: + log.warning("Keyword extraction result missing 'keywords' field") + result["keywords"] = [] + + log.info(f"Extracted keywords: {result.get('keywords', [])}") + + return result + + except Exception as e: + log.error(f"Keyword extraction failed: {e}") + # Add error flag to indicate failure + error_result = data_json.copy() + error_result["error"] = str(e) + error_result["keywords"] = [] + return error_result diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/prompt_generate.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/prompt_generate.py index 317f9e6ac..8c49994fd 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/prompt_generate.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/prompt_generate.py @@ -50,7 +50,7 @@ def node_init(self): "example_name": self.wk_input.example_name, } self.context.assign_from_json(context) - return CStatus() + return super().node_init() def operator_schedule(self, data_json): """ diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py index 7df2e68e7..408adb10a 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py @@ -75,7 +75,7 @@ def node_init(self): } self.context.assign_from_json(_context_payload) - return CStatus() + return super().node_init() def operator_schedule(self, data_json): try: diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/text2gremlin.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/text2gremlin.py index ffbafbaf4..a36831526 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/text2gremlin.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/text2gremlin.py @@ -22,13 +22,15 @@ from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.llm_op.gremlin_generate 
import GremlinGenerateSynthesize -from hugegraph_llm.models.llms.init_llm import LLMs -from hugegraph_llm.config import prompt as prompt_cfg +from hugegraph_llm.models.llms.init_llm import get_text2gql_llm +from hugegraph_llm.config import llm_settings, prompt as prompt_cfg def _stable_schema_string(state_json: Dict[str, Any]) -> str: if "simple_schema" in state_json and state_json["simple_schema"] is not None: - return json.dumps(state_json["simple_schema"], ensure_ascii=False, sort_keys=True) + return json.dumps( + state_json["simple_schema"], ensure_ascii=False, sort_keys=True + ) if "schema" in state_json and state_json["schema"] is not None: return json.dumps(state_json["schema"], ensure_ascii=False, sort_keys=True) return "" @@ -39,7 +41,7 @@ class Text2GremlinNode(BaseNode): def node_init(self): # Select LLM - llm = LLMs().get_text2gql_llm() + llm = get_text2gql_llm(llm_settings) # Serialize schema deterministically state_json = self.context.to_json() schema_str = _stable_schema_string(state_json) diff --git a/hugegraph-llm/src/hugegraph_llm/state/ai_state.py b/hugegraph-llm/src/hugegraph_llm/state/ai_state.py index f941098b1..3a6fd3c1c 100644 --- a/hugegraph-llm/src/hugegraph_llm/state/ai_state.py +++ b/hugegraph-llm/src/hugegraph_llm/state/ai_state.py @@ -24,7 +24,6 @@ class WkFlowInput(GParam): split_type: str = None # split type used by ChunkSplit Node example_prompt: str = None # need by graph information extract schema: str = None # Schema information requeired by SchemaNode - graph_name: str = None data_json = None extract_type = None query_examples = None @@ -34,11 +33,45 @@ class WkFlowInput(GParam): scenario: str = None # Scenario description example_name: str = None # Example name # Fields for Text2Gremlin - query: str = None example_num: int = None gremlin_prompt: str = None requested_outputs: Optional[List[str]] = None + # RAG Flow related fields + query: str = None # User query for RAG + vector_search: bool = None # Enable vector search + 
graph_search: bool = None # Enable graph search + raw_answer: bool = None # Return raw answer + vector_only_answer: bool = None # Vector only answer mode + graph_only_answer: bool = None # Graph only answer mode + graph_vector_answer: bool = None # Combined graph and vector answer + graph_ratio: float = None # Graph ratio for merging + rerank_method: str = None # Reranking method + near_neighbor_first: bool = None # Near neighbor first flag + custom_related_information: str = None # Custom related information + answer_prompt: str = None # Answer generation prompt + keywords_extract_prompt: str = None # Keywords extraction prompt + gremlin_tmpl_num: int = None # Gremlin template number + gremlin_prompt: str = None # Gremlin generation prompt + max_graph_items: int = None # Maximum graph items + topk_return_results: int = None # Top-k return results + vector_dis_threshold: float = None # Vector distance threshold + topk_per_keyword: int = None # Top-k per keyword + max_keywords: int = None + max_items: int = None + + # Semantic query related fields + semantic_by: str = None # Semantic query method + topk_per_query: int = None # Top-k per query + + # Graph query related fields + max_deep: int = None # Maximum depth for graph traversal + max_v_prop_len: int = None # Maximum vertex property length + max_e_prop_len: int = None # Maximum edge property length + prop_to_match: str = None # Property to match + + stream: bool = None # used for recognize stream mode + def reset(self, _: CStatus) -> None: self.texts = None self.language = None @@ -55,10 +88,40 @@ def reset(self, _: CStatus) -> None: self.scenario = None self.example_name = None # Text2Gremlin related configuration - self.query = None self.example_num = None self.gremlin_prompt = None self.requested_outputs = None + # RAG Flow related fields + self.query = None + self.vector_search = None + self.graph_search = None + self.raw_answer = None + self.vector_only_answer = None + self.graph_only_answer = None + 
self.graph_vector_answer = None + self.graph_ratio = None + self.rerank_method = None + self.near_neighbor_first = None + self.custom_related_information = None + self.answer_prompt = None + self.keywords_extract_prompt = None + self.gremlin_tmpl_num = None + self.gremlin_prompt = None + self.max_graph_items = None + self.topk_return_results = None + self.vector_dis_threshold = None + self.topk_per_keyword = None + self.max_keywords = None + self.max_items = None + # Semantic query related fields + self.semantic_by = None + self.topk_per_query = None + # Graph query related fields + self.max_deep = None + self.max_v_prop_len = None + self.max_e_prop_len = None + self.prop_to_match = None + self.stream = None class WkFlowState(GParam): @@ -83,6 +146,17 @@ class WkFlowState(GParam): template_exec_res: Optional[Any] = None raw_exec_res: Optional[Any] = None + match_vids = None + vector_result = None + graph_result = None + + raw_answer: str = None + vector_only_answer: str = None + graph_only_answer: str = None + graph_vector_answer: str = None + + merged_result = None + def setup(self): self.schema = None self.simple_schema = None @@ -90,7 +164,7 @@ def setup(self): self.edges = None self.vertices = None self.triples = None - self.call_count = 0 + self.call_count = None self.keywords = None self.vector_result = None @@ -99,12 +173,20 @@ def setup(self): self.generated_extract_prompt = None # Text2Gremlin results reset - self.match_result = [] - self.result = "" - self.raw_result = "" - self.template_exec_res = "" - self.raw_exec_res = "" + self.match_result = None + self.result = None + self.raw_result = None + self.template_exec_res = None + self.raw_exec_res = None + self.raw_answer = None + self.vector_only_answer = None + self.graph_only_answer = None + self.graph_vector_answer = None + + self.vector_result = None + self.graph_result = None + self.merged_result = None return CStatus() def to_json(self): @@ -116,7 +198,11 @@ def to_json(self): dict: A dictionary 
containing non-None instance members and their serialized values. """ # Only export instance attributes (excluding methods and class attributes) whose values are not None - return {k: v for k, v in self.__dict__.items() if not k.startswith("_") and v is not None} + return { + k: v + for k, v in self.__dict__.items() + if not k.startswith("_") and v is not None + } # Implement a method that assigns keys from data_json as WkFlowState member variables def assign_from_json(self, data_json: dict): diff --git a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py index 7b870033a..3f527f2fa 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py @@ -44,30 +44,17 @@ def get_graph_index_info(): raise gr.Error(str(e)) -def get_graph_index_info_old(): - builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) - graph_summary_info = builder.fetch_graph_data().run() - folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) - index_dir = str(os.path.join(resource_path, folder_name, "graph_vids")) - filename_prefix = get_filename_prefix( - llm_settings.embedding_type, getattr(builder.embedding, "model_name", None) - ) - vector_index = VectorIndex.from_index_file(index_dir, filename_prefix) - graph_summary_info["vid_index"] = { - "embed_dim": vector_index.index.d, - "num_vectors": vector_index.index.ntotal, - "num_vids": len(vector_index.properties), - } - return json.dumps(graph_summary_info, ensure_ascii=False, indent=2) - - def clean_all_graph_index(): - folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) + folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) filename_prefix = get_filename_prefix( llm_settings.embedding_type, getattr(Embeddings().get_embedding(), "model_name", None), ) - 
VectorIndex.clean(str(os.path.join(resource_path, folder_name, "graph_vids")), filename_prefix) + VectorIndex.clean( + str(os.path.join(resource_path, folder_name, "graph_vids")), filename_prefix + ) VectorIndex.clean( str(os.path.join(resource_path, folder_name, "gremlin_examples")), filename_prefix, @@ -99,14 +86,18 @@ def parse_schema(schema: str, builder: KgBuilder) -> Optional[str]: def extract_graph_origin(input_file, input_text, schema, example_prompt) -> str: texts = read_documents(input_file, input_text) - builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) + builder = KgBuilder( + LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client() + ) if not schema: return "ERROR: please input with correct schema/format." error_message = parse_schema(schema, builder) if error_message: return error_message - builder.chunk_split(texts, "document", "zh").extract_info(example_prompt, "property_graph") + builder.chunk_split(texts, "document", "zh").extract_info( + example_prompt, "property_graph" + ) try: context = builder.run() @@ -155,20 +146,6 @@ def update_vid_embedding(): raise gr.Error(str(e)) -def update_vid_embedding_old(): - builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) - builder.fetch_graph_data().build_vertex_id_semantic_index() - log.debug("Operators: %s", builder.operators) - try: - context = builder.run() - removed_num = context["removed_vid_vector_num"] - added_num = context["added_vid_vector_num"] - return f"Removed {removed_num} vectors, added {added_num} vectors." 
- except Exception as e: # pylint: disable=broad-exception-caught - log.error(e) - raise gr.Error(str(e)) - - def import_graph_data(data: str, schema: str) -> Union[str, Dict[str, Any]]: try: scheduler = SchedulerSingleton.get_instance() @@ -181,73 +158,11 @@ def import_graph_data(data: str, schema: str) -> Union[str, Dict[str, Any]]: return data -def import_graph_data_old(data: str, schema: str) -> Union[str, Dict[str, Any]]: - try: - data_json = json.loads(data.strip()) - log.debug("Import graph data: %s", data) - builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) - if schema: - error_message = parse_schema(schema, builder) - if error_message: - return error_message - - context = builder.commit_to_hugegraph().run(data_json) - gr.Info("Import graph data successfully!") - print(context) - return json.dumps(context, ensure_ascii=False, indent=2) - except Exception as e: # pylint: disable=W0718 - log.error(e) - traceback.print_exc() - # Note: can't use gr.Error here - gr.Warning(str(e) + " Please check the graph data format/type carefully.") - return data - - def build_schema(input_text, query_example, few_shot): scheduler = SchedulerSingleton.get_instance() try: - return scheduler.schedule_flow("build_schema", input_text, query_example, few_shot) + return scheduler.schedule_flow( + "build_schema", input_text, query_example, few_shot + ) except (TypeError, ValueError) as e: raise gr.Error(f"Schema generation failed: {e}") - - -def build_schema_old(input_text, query_example, few_shot): - context = { - "raw_texts": [input_text] if input_text else [], - "query_examples": [], - "few_shot_schema": {}, - } - - if few_shot: - try: - context["few_shot_schema"] = json.loads(few_shot) - except json.JSONDecodeError as e: - raise gr.Error(f"Few Shot Schema is not in a valid JSON format: {e}") from e - - if query_example: - try: - parsed_examples = json.loads(query_example) - # Validate and retain the description and gremlin fields - 
context["query_examples"] = [ - { - "description": ex.get("description", ""), - "gremlin": ex.get("gremlin", ""), - } - for ex in parsed_examples - if isinstance(ex, dict) and "description" in ex and "gremlin" in ex - ] - except json.JSONDecodeError as e: - raise gr.Error(f"Query Examples is not in a valid JSON format: {e}") from e - - builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) - try: - schema = builder.build_schema().run(context) - except Exception as e: - log.error("Failed to generate schema: %s", e) - raise gr.Error(f"Schema generation failed: {e}") from e - try: - formatted_schema = json.dumps(schema, ensure_ascii=False, indent=2) - return formatted_schema - except (TypeError, ValueError) as e: - log.error("Failed to format schema: %s", e) - return str(schema) From af82d914c61ffca9664406cf4e0f8d0222686816 Mon Sep 17 00:00:00 2001 From: jinglinwei Date: Mon, 13 Oct 2025 01:01:47 +0800 Subject: [PATCH 5/5] refactor: port batch build gremlin examples & delete some doc related to Pipeline(old design) & refactor some operator's design and implementation & code format --- README.md | 36 -- hugegraph-llm/README.md | 155 +++--- hugegraph-llm/pyproject.toml | 6 +- .../src/hugegraph_llm/api/admin_api.py | 2 +- .../src/hugegraph_llm/api/rag_api.py | 72 ++- .../hugegraph_llm/demo/rag_demo/rag_block.py | 36 +- .../demo/rag_demo/text2gremlin_block.py | 158 ++---- .../demo/rag_demo/vector_graph_block.py | 59 ++- .../src/hugegraph_llm/flows/__init__.py | 18 + .../flows/build_example_index.py | 62 +++ .../src/hugegraph_llm/flows/build_schema.py | 16 +- .../hugegraph_llm/flows/build_vector_index.py | 16 +- .../src/hugegraph_llm/flows/common.py | 29 +- .../flows/get_graph_index_info.py | 16 +- .../src/hugegraph_llm/flows/graph_extract.py | 25 +- .../hugegraph_llm/flows/import_graph_data.py | 8 +- .../hugegraph_llm/flows/prompt_generate.py | 19 +- .../flows/rag_flow_graph_only.py | 110 ++-- .../flows/rag_flow_graph_vector.py | 55 +- 
.../src/hugegraph_llm/flows/rag_flow_raw.py | 49 +- .../flows/rag_flow_vector_only.py | 53 +- .../src/hugegraph_llm/flows/scheduler.py | 118 +++-- .../src/hugegraph_llm/flows/text2gremlin.py | 14 +- .../flows/update_vid_embeddings.py | 18 +- .../src/hugegraph_llm/flows/utils.py | 34 -- .../models/embeddings/init_embedding.py | 28 +- .../src/hugegraph_llm/models/llms/init_llm.py | 96 ++-- .../src/hugegraph_llm/nodes/base_node.py | 43 +- .../nodes/common_node/merge_rerank_node.py | 14 +- .../nodes/document_node/chunk_split.py | 2 +- .../nodes/hugegraph_node/fetch_graph_data.py | 9 +- .../nodes/hugegraph_node/graph_query_node.py | 486 ++++++++++++++++-- .../nodes/hugegraph_node/schema.py | 26 +- .../index_node/build_gremlin_example_index.py | 43 ++ .../index_node/semantic_id_query_node.py | 92 ++-- .../nodes/index_node/vector_query_node.py | 27 +- .../nodes/llm_node/answer_synthesize_node.py | 103 ++-- .../nodes/llm_node/extract_info.py | 3 +- .../nodes/llm_node/keyword_extract_node.py | 39 +- .../nodes/llm_node/schema_build.py | 10 +- .../nodes/llm_node/text2gremlin.py | 18 +- hugegraph-llm/src/hugegraph_llm/nodes/util.py | 29 +- .../operators/gremlin_generate_task.py | 81 --- .../hugegraph_op/commit_to_hugegraph.py | 59 ++- .../operators/hugegraph_op/graph_rag_query.py | 455 ---------------- .../index_op/build_semantic_index.py | 24 +- .../operators/kg_construction_task.py | 120 ----- .../operators/llm_op/keyword_extract.py | 58 +-- .../{graph_rag_task.py => operator_list.py} | 233 +++++---- .../src/hugegraph_llm/operators/util.py | 27 - .../src/hugegraph_llm/state/ai_state.py | 186 ++++--- .../hugegraph_llm/utils/graph_index_utils.py | 96 ++-- .../hugegraph_llm/utils/vector_index_utils.py | 19 +- hugegraph-ml/pyproject.toml | 2 +- hugegraph-python-client/pyproject.toml | 2 +- pyproject.toml | 2 +- scripts/build_llm_image.sh | 2 +- .../hugegraph-llm/fixed_flow/design.md | 3 +- .../hugegraph-llm/fixed_flow/requirements.md | 0 .../hugegraph-llm/fixed_flow/tasks.md | 0 
style/pylint.conf | 4 +- vermeer-python-client/pyproject.toml | 4 +- 62 files changed, 1760 insertions(+), 1869 deletions(-) create mode 100644 hugegraph-llm/src/hugegraph_llm/flows/build_example_index.py delete mode 100644 hugegraph-llm/src/hugegraph_llm/flows/utils.py create mode 100644 hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_gremlin_example_index.py delete mode 100644 hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py delete mode 100644 hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py delete mode 100644 hugegraph-llm/src/hugegraph_llm/operators/kg_construction_task.py rename hugegraph-llm/src/hugegraph_llm/operators/{graph_rag_task.py => operator_list.py} (54%) delete mode 100644 hugegraph-llm/src/hugegraph_llm/operators/util.py mode change 100644 => 100755 scripts/build_llm_image.sh rename {.vibedev/spec => spec}/hugegraph-llm/fixed_flow/design.md (99%) rename {.vibedev/spec => spec}/hugegraph-llm/fixed_flow/requirements.md (100%) rename {.vibedev/spec => spec}/hugegraph-llm/fixed_flow/tasks.md (100%) diff --git a/README.md b/README.md index 14f02ca1c..a495968ec 100644 --- a/README.md +++ b/README.md @@ -75,42 +75,6 @@ python -m hugegraph_llm.demo.rag_demo.app > [!NOTE] > Examples assume you've activated the virtual environment with `source .venv/bin/activate` -#### GraphRAG - Question Answering - -```python -from hugegraph_llm.operators.graph_rag_task import RAGPipeline - -# Initialize RAG pipeline -graph_rag = RAGPipeline() - -# Ask questions about your graph -result = (graph_rag - .extract_keywords(text="Tell me about Al Pacino.") - .keywords_to_vid() - .query_graphdb(max_deep=2, max_graph_items=30) - .merge_dedup_rerank() - .synthesize_answer() - .run()) -``` - -#### Knowledge Graph Construction - -```python -from hugegraph_llm.models.llms.init_llm import LLMs -from hugegraph_llm.operators.kg_construction_task import KgBuilder - -# Build KG from text -TEXT = "Your text content here..." 
-builder = KgBuilder(LLMs().get_chat_llm()) - -(builder - .import_schema(from_hugegraph="hugegraph") - .chunk_split(TEXT) - .extract_info(extract_type="property_graph") - .commit_to_hugegraph() - .run()) -``` - #### Graph Machine Learning ```bash diff --git a/hugegraph-llm/README.md b/hugegraph-llm/README.md index 526320d4a..8b7e15c50 100644 --- a/hugegraph-llm/README.md +++ b/hugegraph-llm/README.md @@ -89,14 +89,14 @@ curl -LsSf https://astral.sh/uv/install.sh | sh # 3. Clone and setup project git clone https://github.com/apache/incubator-hugegraph-ai.git -cd incubator-hugegraph-ai/hugegraph-llm +cd incubator-hugegraph-ai # Configure environment (see config.md for detailed options), .env will auto create if not exists # 4. Install dependencies and activate environment # NOTE: If download is slow, uncomment mirror lines in ../pyproject.toml or use: uv config --global index.url https://pypi.tuna.tsinghua.edu.cn/simple # Or create local uv.toml with mirror settings to avoid git diff (see uv.toml example in root) -uv sync # Automatically creates .venv and installs dependencies +uv sync --extra llm # Automatically creates .venv and installs dependencies source .venv/bin/activate # Activate once - all commands below assume this environment # 5. Launch RAG demo @@ -146,84 +146,6 @@ Use the Gradio interface for visual knowledge graph building: ![Knowledge Graph Builder](https://hugegraph.apache.org/docs/images/gradio-kg.png) -#### Programmatic Construction - -Build knowledge graphs with code using the `KgBuilder` class: - -```python -from hugegraph_llm.models.llms.init_llm import LLMs -from hugegraph_llm.operators.kg_construction_task import KgBuilder - -# Initialize and chain operations -TEXT = "Your input text here..." 
-builder = KgBuilder(LLMs().get_chat_llm()) - -( - builder - .import_schema(from_hugegraph="talent_graph").print_result() - .chunk_split(TEXT).print_result() - .extract_info(extract_type="property_graph").print_result() - .commit_to_hugegraph() - .run() -) -``` - -**Pipeline Workflow:** - -```mermaid -graph LR - A[Import Schema] --> B[Chunk Split] - B --> C[Extract Info] - C --> D[Commit to HugeGraph] - D --> E[Execute Pipeline] - - style A fill:#fff2cc - style B fill:#d5e8d4 - style C fill:#dae8fc - style D fill:#f8cecc - style E fill:#e1d5e7 -``` - -### Graph-Enhanced RAG - -Leverage HugeGraph for retrieval-augmented generation: - -```python -from hugegraph_llm.operators.graph_rag_task import RAGPipeline - -# Initialize RAG pipeline -graph_rag = RAGPipeline() - -# Execute RAG workflow -( - graph_rag - .extract_keywords(text="Tell me about Al Pacino.") - .keywords_to_vid() - .query_graphdb(max_deep=2, max_graph_items=30) - .merge_dedup_rerank() - .synthesize_answer(vector_only_answer=False, graph_only_answer=True) - .run(verbose=True) -) -``` - -**RAG Pipeline Flow:** - -```mermaid -graph TD - A[User Query] --> B[Extract Keywords] - B --> C[Match Graph Nodes] - C --> D[Retrieve Graph Context] - D --> E[Rerank Results] - E --> F[Generate Answer] - - style A fill:#e3f2fd - style B fill:#f3e5f5 - style C fill:#e8f5e8 - style D fill:#fff3e0 - style E fill:#fce4ec - style F fill:#e0f2f1 -``` - ## 🔧 Configuration After running the demo, configuration files are automatically generated: @@ -248,6 +170,79 @@ The system supports both English and Chinese prompts. To switch languages: **LLM Provider Support**: This project uses [LiteLLM](https://docs.litellm.ai/docs/providers) for multi-provider LLM support. +### Programmatic Examples (new workflow engine) + +If you previously used high-level classes like `RAGPipeline` or `KgBuilder`, the project now exposes stable flows through the `Scheduler` API. 
Use `SchedulerSingleton.get_instance().schedule_flow(...)` to invoke workflows programmatically. Below are concise, working examples that match the new architecture. + +1) RAG (graph-only) query example + +```python +from hugegraph_llm.flows.scheduler import SchedulerSingleton + +scheduler = SchedulerSingleton.get_instance() +res = scheduler.schedule_flow( + "rag_graph_only", + query="Tell me about Al Pacino.", + graph_only_answer=True, + vector_only_answer=False, + raw_answer=False, + gremlin_tmpl_num=-1, + gremlin_prompt=None, +) + +print(res.get("graph_only_answer")) +``` + +2) RAG (vector-only) query example + +```python +from hugegraph_llm.flows.scheduler import SchedulerSingleton + +scheduler = SchedulerSingleton.get_instance() +res = scheduler.schedule_flow( + "rag_vector_only", + query="Summarize the career of Ada Lovelace.", + vector_only_answer=True, + vector_search=True +) + +print(res.get("vector_only_answer")) +``` + +3) Text -> Gremlin (text2gremlin) example + +```python +from hugegraph_llm.flows.scheduler import SchedulerSingleton + +scheduler = SchedulerSingleton.get_instance() +response = scheduler.schedule_flow( + "text2gremlin", + "find people who worked with Alan Turing", + 2, # example_num + "hugegraph", # schema_input (graph name or schema) + None, # gremlin_prompt_input (optional) + ["template_gremlin", "raw_gremlin"], +) + +print(response.get("template_gremlin")) +``` + +4) Build example index (used by text2gremlin examples) + +```python +from hugegraph_llm.flows.scheduler import SchedulerSingleton + +examples = [{"id": "natural language query", "gremlin": "g.V().hasLabel('person').valueMap()"}] +res = SchedulerSingleton.get_instance().schedule_flow("build_examples_index", examples) +print(res) +``` + +### Migration guide: RAGPipeline / KgBuilder → Scheduler flows + +Why the change: the internal execution engine was refactored to a pipeline-based scheduler (GPipeline + GPipelineManager). 
The scheduler provides a stable entrypoint while keeping flow implementations modular. + +If you need help migrating a specific snippet, open a PR or issue and include the old code — we can provide a targeted conversion. + ## 🤖 Developer Guidelines > [!IMPORTANT] > **For developers contributing to hugegraph-llm with AI coding assistance:** diff --git a/hugegraph-llm/pyproject.toml b/hugegraph-llm/pyproject.toml index 2b0f29ace..b3894d7e5 100644 --- a/hugegraph-llm/pyproject.toml +++ b/hugegraph-llm/pyproject.toml @@ -17,7 +17,7 @@ [project] name = "hugegraph-llm" -version = "1.5.0" +version = "1.7.0" description = "A tool for the implementation and research related to large language models." authors = [ { name = "Apache HugeGraph Contributors", email = "dev@hugegraph.apache.org" }, @@ -89,4 +89,6 @@ allow-direct-references = true [tool.uv.sources] hugegraph-python-client = { workspace = true } -pycgraph = { git = "https://github.com/ChunelFeng/CGraph.git", subdirectory = "python", rev = "main", marker = "sys_platform == 'linux'" } +# We encountered a bug in PyCGraph's latest release version, so we're using a specific commit from the main branch (without the bug) as the project dependency. +# TODO: Replace this command in the future when a new PyCGraph release version (after 3.1.2) is available. 
+pycgraph = { git = "https://github.com/ChunelFeng/CGraph.git", subdirectory = "python", rev = "248bfcfeddfa2bc23a1d585a3925c71189dba6cc"} diff --git a/hugegraph-llm/src/hugegraph_llm/api/admin_api.py b/hugegraph-llm/src/hugegraph_llm/api/admin_api.py index 4c192c29c..109da4a99 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/admin_api.py +++ b/hugegraph-llm/src/hugegraph_llm/api/admin_api.py @@ -31,7 +31,7 @@ def admin_http_api(router: APIRouter, log_stream): @router.post("/logs", status_code=status.HTTP_200_OK) async def log_stream_api(req: LogStreamRequest): if admin_settings.admin_token != req.admin_token: - raise generate_response( + raise generate_response( # pylint: disable=raising-bad-type RAGResponse( status_code=status.HTTP_403_FORBIDDEN, # pylint: disable=E0702 message="Invalid admin_token", diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py index 356176e4e..ca29cb9ab 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py +++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py @@ -31,9 +31,8 @@ from hugegraph_llm.api.models.rag_response import RAGResponse from hugegraph_llm.config import huge_settings from hugegraph_llm.config import llm_settings, prompt +from hugegraph_llm.utils.graph_index_utils import get_vertex_details from hugegraph_llm.utils.log import log -from hugegraph_llm.flows.scheduler import SchedulerSingleton - # pylint: disable=too-many-statements @@ -51,6 +50,13 @@ def rag_http_api( def rag_answer_api(req: RAGRequest): set_graph_config(req) + # Basic parameter validation: empty query => 400 + if not req.query or not str(req.query).strip(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Query must not be empty.", + ) + result = rag_answer_func( text=req.query, raw_answer=req.raw_answer, @@ -68,7 +74,8 @@ def rag_answer_api(req: RAGRequest): # Keep prompt params in the end custom_related_information=req.custom_priority_info, 
answer_prompt=req.answer_prompt or prompt.answer_prompt, - keywords_extract_prompt=req.keywords_extract_prompt or prompt.keywords_extract_prompt, + keywords_extract_prompt=req.keywords_extract_prompt + or prompt.keywords_extract_prompt, gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt, ) # TODO: we need more info in the response for users to understand the query logic @@ -77,7 +84,8 @@ def rag_answer_api(req: RAGRequest): **{ key: value for key, value in zip( - ["raw_answer", "vector_only", "graph_only", "graph_vector_answer"], result + ["raw_answer", "vector_only", "graph_only", "graph_vector_answer"], + result, ) if getattr(req, key) }, @@ -96,6 +104,13 @@ def graph_rag_recall_api(req: GraphRAGRequest): try: set_graph_config(req) + # Basic parameter validation: empty query => 400 + if not req.query or not str(req.query).strip(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Query must not be empty.", + ) + result = graph_rag_recall_func( query=req.query, max_graph_items=req.max_graph_items, @@ -111,12 +126,7 @@ def graph_rag_recall_api(req: GraphRAGRequest): ) if req.get_vertex_only: - from hugegraph_llm.operators.hugegraph_op.graph_rag_query import GraphRAGQuery - - graph_rag = GraphRAGQuery() - graph_rag.init_client(result) - vertex_details = graph_rag.get_vertex_details(result["match_vids"]) - + vertex_details = get_vertex_details(result["match_vids"], result) if vertex_details: result["match_vids"] = vertex_details @@ -136,7 +146,9 @@ def graph_rag_recall_api(req: GraphRAGRequest): except TypeError as e: log.error("TypeError in graph_rag_recall_api: %s", e) - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail=str(e) + ) from e except Exception as e: log.error("Unexpected error occurred: %s", e) raise HTTPException( @@ -147,7 +159,9 @@ def graph_rag_recall_api(req: GraphRAGRequest): 
@router.post("/config/graph", status_code=status.HTTP_201_CREATED) def graph_config_api(req: GraphConfigRequest): # Accept status code - res = apply_graph_conf(req.url, req.name, req.user, req.pwd, req.gs, origin_call="http") + res = apply_graph_conf( + req.url, req.name, req.user, req.pwd, req.gs, origin_call="http" + ) return generate_response(RAGResponse(status_code=res, message="Missing Value")) # TODO: restructure the implement of llm to three types, like "/config/chat_llm" @@ -157,10 +171,16 @@ def llm_config_api(req: LLMConfigRequest): if req.llm_type == "openai": res = apply_llm_conf( - req.api_key, req.api_base, req.language_model, req.max_tokens, origin_call="http" + req.api_key, + req.api_base, + req.language_model, + req.max_tokens, + origin_call="http", ) else: - res = apply_llm_conf(req.host, req.port, req.language_model, None, origin_call="http") + res = apply_llm_conf( + req.host, req.port, req.language_model, None, origin_call="http" + ) return generate_response(RAGResponse(status_code=res, message="Missing Value")) @router.post("/config/embedding", status_code=status.HTTP_201_CREATED) @@ -172,7 +192,9 @@ def embedding_config_api(req: LLMConfigRequest): req.api_key, req.api_base, req.language_model, origin_call="http" ) else: - res = apply_embedding_conf(req.host, req.port, req.language_model, origin_call="http") + res = apply_embedding_conf( + req.host, req.port, req.language_model, origin_call="http" + ) return generate_response(RAGResponse(status_code=res, message="Missing Value")) @router.post("/config/rerank", status_code=status.HTTP_201_CREATED) @@ -184,7 +206,9 @@ def rerank_config_api(req: RerankerConfigRequest): req.api_key, req.reranker_model, req.cohere_base_url, origin_call="http" ) elif req.reranker_type == "siliconflow": - res = apply_reranker_conf(req.api_key, req.reranker_model, None, origin_call="http") + res = apply_reranker_conf( + req.api_key, req.reranker_model, None, origin_call="http" + ) else: res = 
status.HTTP_501_NOT_IMPLEMENTED return generate_response(RAGResponse(status_code=res, message="Missing Value")) @@ -197,20 +221,20 @@ def text2gremlin_api(req: GremlinGenerateRequest): # Basic parameter validation: empty query => 400 if not req.query or not str(req.query).strip(): raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="Query must not be empty." + status_code=status.HTTP_400_BAD_REQUEST, + detail="Query must not be empty.", ) output_types_str_list = None if req.output_types: output_types_str_list = [ot.value for ot in req.output_types] - response_dict = SchedulerSingleton.get_instance().schedule_flow( - "text2gremlin", - req.query, - req.example_num, - huge_settings.graph_name, - req.gremlin_prompt, - output_types_str_list, + response_dict = gremlin_generate_selective_func( + inp=req.query, + example_num=req.example_num, + schema_input=huge_settings.graph_name, + gremlin_prompt_input=req.gremlin_prompt, + requested_outputs=output_types_str_list, ) return response_dict except HTTPException as e: diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py index ca36867d9..9bf04b570 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py @@ -19,12 +19,12 @@ import os from typing import AsyncGenerator, Literal, Optional, Tuple - -import gradio as gr -from hugegraph_llm.flows.scheduler import SchedulerSingleton import pandas as pd +import gradio as gr from gradio.utils import NamedString +from hugegraph_llm.flows import FlowName +from hugegraph_llm.flows.scheduler import SchedulerSingleton from hugegraph_llm.config import resource_path, prompt, llm_settings from hugegraph_llm.utils.decorators import with_task_id from hugegraph_llm.utils.log import log @@ -51,11 +51,7 @@ def rag_answer( ) -> Tuple: """ Generate an answer using the RAG (Retrieval-Augmented Generation) pipeline. - 1. 
Initialize the RAGPipeline. - 2. Select vector search or graph search based on parameters. - 3. Merge, deduplicate, and rerank the results. - 4. Synthesize the final answer. - 5. Run the pipeline and return the results. + Fetch the Scheduler to deal with the request """ graph_search, gremlin_prompt, vector_search = update_ui_configs( answer_prompt, @@ -75,13 +71,13 @@ def rag_answer( try: # Select workflow by mode to avoid fetching the wrong pipeline from the pool if graph_vector_answer or (graph_only_answer and vector_only_answer): - flow_key = "rag_graph_vector" + flow_key = FlowName.RAG_GRAPH_VECTOR elif vector_only_answer: - flow_key = "rag_vector_only" + flow_key = FlowName.RAG_VECTOR_ONLY elif graph_only_answer: - flow_key = "rag_graph_only" + flow_key = FlowName.RAG_GRAPH_ONLY elif raw_answer: - flow_key = "rag_raw" + flow_key = FlowName.RAG_RAW else: raise RuntimeError("Unsupported flow type") @@ -172,11 +168,7 @@ async def rag_answer_streaming( ) -> AsyncGenerator[Tuple[str, str, str, str], None]: """ Generate an answer using the RAG (Retrieval-Augmented Generation) pipeline. - 1. Initialize the RAGPipeline. - 2. Select vector search or graph search based on parameters. - 3. Merge, deduplicate, and rerank the results. - 4. Synthesize the final answer. - 5. Run the pipeline and return the results. 
+ Fetch the Scheduler to deal with the request """ graph_search, gremlin_prompt, vector_search = update_ui_configs( answer_prompt, @@ -197,13 +189,13 @@ async def rag_answer_streaming( # Select the specific streaming workflow scheduler = SchedulerSingleton.get_instance() if graph_vector_answer or (graph_only_answer and vector_only_answer): - flow_key = "rag_graph_vector" + flow_key = FlowName.RAG_GRAPH_VECTOR elif vector_only_answer: - flow_key = "rag_vector_only" + flow_key = FlowName.RAG_VECTOR_ONLY elif graph_only_answer: - flow_key = "rag_graph_only" + flow_key = FlowName.RAG_GRAPH_ONLY elif raw_answer: - flow_key = "rag_raw" + flow_key = FlowName.RAG_RAW else: raise RuntimeError("Unsupported flow type") @@ -367,7 +359,7 @@ def toggle_slider(enable): ) gr.Markdown( - """## 2. (Batch) Back-testing ) + """## 2. (Batch) Back-testing > 1. Download the template file & fill in the questions you want to test. > 2. Upload the file & click the button to generate answers. (Preview shows the first 40 lines) > 3. 
The answer options are the same as the above RAG/Q&A frame diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py index 6600d7c41..aa9c2f0c5 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py @@ -25,11 +25,7 @@ import pandas as pd from hugegraph_llm.config import prompt, resource_path, huge_settings -from hugegraph_llm.models.embeddings.init_embedding import Embeddings -from hugegraph_llm.models.llms.init_llm import LLMs -from hugegraph_llm.operators.graph_rag_task import RAGPipeline -from hugegraph_llm.operators.gremlin_generate_task import GremlinGenerator -from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager +from hugegraph_llm.flows import FlowName from hugegraph_llm.utils.embedding_utils import get_index_folder_name from hugegraph_llm.utils.hugegraph_utils import run_gremlin_query from hugegraph_llm.utils.log import log @@ -86,7 +82,9 @@ def store_schema(schema, question, gremlin_prompt): def build_example_vector_index(temp_file) -> dict: - folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) + folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) index_path = os.path.join(resource_path, folder_name, "gremlin_examples") if not os.path.exists(index_path): os.makedirs(index_path) @@ -98,7 +96,9 @@ def build_example_vector_index(temp_file) -> dict: timestamp = datetime.now().strftime("%Y%m%d%H%M%S") _, file_name = os.path.split(f"{name}_{timestamp}{ext}") log.info("Copying file to: %s", file_name) - target_file = os.path.join(resource_path, folder_name, "gremlin_examples", file_name) + target_file = os.path.join( + resource_path, folder_name, "gremlin_examples", file_name + ) try: import shutil @@ -116,11 +116,10 @@ def build_example_vector_index(temp_file) -> dict: 
else: log.critical("Unsupported file format. Please input a JSON or CSV file.") return {"error": "Unsupported file format. Please input a JSON or CSV file."} - builder = GremlinGenerator( - llm=LLMs().get_text2gql_llm(), - embedding=Embeddings().get_embedding(), + + return SchedulerSingleton.get_instance().schedule_flow( + FlowName.BUILD_EXAMPLES_INDEX, examples ) - return builder.example_index_build(examples).run() def _process_schema(schema, generator, sm): @@ -182,43 +181,6 @@ def _execute_queries(context, output_types): context["raw_exec_res"] = "" -def gremlin_generate( - inp, example_num, schema, gremlin_prompt, requested_outputs: Optional[List[str]] = None -) -> GremlinResult: - generator = GremlinGenerator( - llm=LLMs().get_text2gql_llm(), embedding=Embeddings().get_embedding() - ) - sm = SchemaManager(graph_name=schema) - - processed_schema, short_schema = _process_schema(schema, generator, sm) - if processed_schema is None and short_schema is None: - return GremlinResult.error("Invalid JSON schema, please check the format carefully.") - - updated_schema = sm.simple_schema(processed_schema) if short_schema else processed_schema - store_schema(str(updated_schema), inp, gremlin_prompt) - - output_types = _configure_output_types(requested_outputs) - - context = ( - generator.example_index_query(example_num) - .gremlin_generate_synthesize(updated_schema, gremlin_prompt) - .run(query=inp) - ) - - _execute_queries(context, output_types) - - match_result = json.dumps( - context.get("match_result", "No Results"), ensure_ascii=False, indent=2 - ) - return GremlinResult.success_result( - match_result=match_result, - template_gremlin=context["result"], - raw_gremlin=context["raw_result"], - template_exec=context["template_exec_res"], - raw_exec=context["raw_exec_res"], - ) - - def simple_schema(schema: Dict[str, Any]) -> Dict[str, Any]: mini_schema = {} @@ -226,7 +188,11 @@ def simple_schema(schema: Dict[str, Any]) -> Dict[str, Any]: if "vertexlabels" in schema: 
mini_schema["vertexlabels"] = [] for vertex in schema["vertexlabels"]: - new_vertex = {key: vertex[key] for key in ["id", "name", "properties"] if key in vertex} + new_vertex = { + key: vertex[key] + for key in ["id", "name", "properties"] + if key in vertex + } mini_schema["vertexlabels"].append(new_vertex) # Add necessary edgelabels items (4) @@ -248,7 +214,7 @@ def gremlin_generate_for_ui(inp, example_num, schema, gremlin_prompt): # Execute via scheduler try: res = SchedulerSingleton.get_instance().schedule_flow( - "text2gremlin", + FlowName.TEXT2GREMLIN, inp, int(example_num) if isinstance(example_num, (int, float, str)) else 2, schema, @@ -305,15 +271,21 @@ def create_text2gremlin_block() -> Tuple: with gr.Row(): with gr.Column(scale=1): input_box = gr.Textbox( - value=prompt.default_question, label="Nature Language Query", show_copy_button=True + value=prompt.default_question, + label="Nature Language Query", + show_copy_button=True, ) match = gr.Code( label="Similar Template (TopN)", language="javascript", elem_classes="code-container-show", ) - initialized_out = gr.Textbox(label="Gremlin With Template", show_copy_button=True) - raw_out = gr.Textbox(label="Gremlin Without Template", show_copy_button=True) + initialized_out = gr.Textbox( + label="Gremlin With Template", show_copy_button=True + ) + raw_out = gr.Textbox( + label="Gremlin Without Template", show_copy_button=True + ) tmpl_exec_out = gr.Code( label="Query With Template Output", language="json", @@ -330,7 +302,10 @@ def create_text2gremlin_block() -> Tuple: minimum=0, maximum=10, step=1, value=2, label="Number of refer examples" ) schema_box = gr.Textbox( - value=prompt.text2gql_graph_schema, label="Schema", lines=2, show_copy_button=True + value=prompt.text2gql_graph_schema, + label="Schema", + lines=2, + show_copy_button=True, ) prompt_box = gr.Textbox( value=prompt.gremlin_generate_prompt, @@ -362,24 +337,21 @@ def graph_rag_recall( get_vertex_only: bool = False, ) -> dict: 
store_schema(prompt.text2gql_graph_schema, query, gremlin_prompt) - rag = RAGPipeline() - rag.extract_keywords().keywords_to_vid( + context = SchedulerSingleton.get_instance().schedule_flow( + FlowName.RAG_GRAPH_ONLY, + query=query, + gremlin_tmpl_num=gremlin_tmpl_num, + rerank_method=rerank_method, + near_neighbor_first=near_neighbor_first, + custom_related_information=custom_related_information, + gremlin_prompt=gremlin_prompt, + max_graph_items=max_graph_items, + topk_return_results=topk_return_results, vector_dis_threshold=vector_dis_threshold, topk_per_keyword=topk_per_keyword, + is_graph_rag_recall=True, + is_vector_only=get_vertex_only, ) - - if not get_vertex_only: - rag.import_schema(huge_settings.graph_name).query_graphdb( - num_gremlin_generate_example=gremlin_tmpl_num, - gremlin_prompt=gremlin_prompt, - max_graph_items=max_graph_items, - ).merge_dedup_rerank( - rerank_method=rerank_method, - near_neighbor_first=near_neighbor_first, - custom_related_information=custom_related_information, - topk_return_results=topk_return_results, - ) - context = rag.run(verbose=True, query=query, graph_search=True) return context @@ -390,45 +362,13 @@ def gremlin_generate_selective( gremlin_prompt_input: str, requested_outputs: Optional[List[str]] = None, ) -> Dict[str, Any]: - """ - Wraps the gremlin_generate function to return a dictionary of outputs - based on the requested_outputs list of strings. 
- """ - output_keys = [ - "match_result", - "template_gremlin", - "raw_gremlin", - "template_execution_result", - "raw_execution_result", - ] - if not requested_outputs: # None or empty list - requested_outputs = output_keys - - result = gremlin_generate( - inp, example_num, schema_input, gremlin_prompt_input, requested_outputs + response_dict = SchedulerSingleton.get_instance().schedule_flow( + FlowName.TEXT2GREMLIN, + inp, + example_num, + schema_input, + gremlin_prompt_input, + requested_outputs, ) - outputs_dict: Dict[str, Any] = {} - - if not result.success: - # Handle error case - if "match_result" in requested_outputs: - outputs_dict["match_result"] = result.match_result - if result.error_message: - outputs_dict["error_detail"] = result.error_message - return outputs_dict - - # Handle successful case - output_mapping = { - "match_result": result.match_result, - "template_gremlin": result.template_gremlin, - "raw_gremlin": result.raw_gremlin, - "template_execution_result": result.template_exec_result, - "raw_execution_result": result.raw_exec_result, - } - - for key in requested_outputs: - if key in output_mapping: - outputs_dict[key] = output_mapping[key] - - return outputs_dict + return response_dict diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py index 56b5de4b3..84d60df7e 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/vector_graph_block.py @@ -26,6 +26,7 @@ from hugegraph_llm.config import huge_settings from hugegraph_llm.config import prompt from hugegraph_llm.config import resource_path +from hugegraph_llm.flows import FlowName from hugegraph_llm.flows.scheduler import SchedulerSingleton from hugegraph_llm.utils.graph_index_utils import ( get_graph_index_info, @@ -63,12 +64,16 @@ def generate_prompt_for_ui(source_text, scenario, example_name): Handles the UI logic for 
generating a new prompt using the new workflow architecture. """ if not all([source_text, scenario, example_name]): - gr.Warning("Please provide original text, expected scenario, and select an example!") + gr.Warning( + "Please provide original text, expected scenario, and select an example!" + ) return gr.update() try: # using new architecture scheduler = SchedulerSingleton.get_instance() - result = scheduler.schedule_flow("prompt_generate", source_text, scenario, example_name) + result = scheduler.schedule_flow( + FlowName.PROMPT_GENERATE, source_text, scenario, example_name + ) gr.Info("Prompt generated successfully!") return result except Exception as e: @@ -79,7 +84,9 @@ def generate_prompt_for_ui(source_text, scenario, example_name): def load_example_names(): """Load all candidate examples""" try: - examples_path = os.path.join(resource_path, "prompt_examples", "prompt_examples.json") + examples_path = os.path.join( + resource_path, "prompt_examples", "prompt_examples.json" + ) with open(examples_path, "r", encoding="utf-8") as f: examples = json.load(f) return [example.get("name", "Unnamed example") for example in examples] @@ -100,16 +107,22 @@ def load_query_examples(): ), ) if language.upper() == "CN": - examples_path = os.path.join(resource_path, "prompt_examples", "query_examples_CN.json") + examples_path = os.path.join( + resource_path, "prompt_examples", "query_examples_CN.json" + ) else: - examples_path = os.path.join(resource_path, "prompt_examples", "query_examples.json") + examples_path = os.path.join( + resource_path, "prompt_examples", "query_examples.json" + ) with open(examples_path, "r", encoding="utf-8") as f: examples = json.load(f) return json.dumps(examples, indent=2, ensure_ascii=False) except (FileNotFoundError, json.JSONDecodeError): try: - examples_path = os.path.join(resource_path, "prompt_examples", "query_examples.json") + examples_path = os.path.join( + resource_path, "prompt_examples", "query_examples.json" + ) with 
open(examples_path, "r", encoding="utf-8") as f: examples = json.load(f) return json.dumps(examples, indent=2, ensure_ascii=False) @@ -120,7 +133,9 @@ def load_query_examples(): def load_schema_fewshot_examples(): """Load few-shot examples from a JSON file""" try: - examples_path = os.path.join(resource_path, "prompt_examples", "schema_examples.json") + examples_path = os.path.join( + resource_path, "prompt_examples", "schema_examples.json" + ) with open(examples_path, "r", encoding="utf-8") as f: examples = json.load(f) return json.dumps(examples, indent=2, ensure_ascii=False) @@ -131,10 +146,14 @@ def load_schema_fewshot_examples(): def update_example_preview(example_name): """Update the display content based on the selected example name.""" try: - examples_path = os.path.join(resource_path, "prompt_examples", "prompt_examples.json") + examples_path = os.path.join( + resource_path, "prompt_examples", "prompt_examples.json" + ) with open(examples_path, "r", encoding="utf-8") as f: all_examples = json.load(f) - selected_example = next((ex for ex in all_examples if ex.get("name") == example_name), None) + selected_example = next( + (ex for ex in all_examples if ex.get("name") == example_name), None + ) if selected_example: return ( @@ -179,7 +198,9 @@ def _create_prompt_helper_block(demo, input_text, info_extract_template): interactive=False, ) - generate_prompt_btn = gr.Button("🚀 Auto-generate Graph Extract Prompt", variant="primary") + generate_prompt_btn = gr.Button( + "🚀 Auto-generate Graph Extract Prompt", variant="primary" + ) # Bind the change event of the dropdown menu few_shot_dropdown.change( fn=update_example_preview, @@ -271,7 +292,9 @@ def create_vector_graph_block(): lines=15, max_lines=29, ) - out = gr.Code(label="Output Info", language="json", elem_classes="code-container-edit") + out = gr.Code( + label="Output Info", language="json", elem_classes="code-container-edit" + ) with gr.Row(): with gr.Accordion("Get RAG Info", open=False): @@ -280,8 
+303,12 @@ def create_vector_graph_block(): graph_index_btn0 = gr.Button("Get Graph Index Info", size="sm") with gr.Accordion("Clear RAG Data", open=False): with gr.Column(): - vector_index_btn1 = gr.Button("Clear Chunks Vector Index", size="sm") - graph_index_btn1 = gr.Button("Clear Graph Vid Vector Index", size="sm") + vector_index_btn1 = gr.Button( + "Clear Chunks Vector Index", size="sm" + ) + graph_index_btn1 = gr.Button( + "Clear Graph Vid Vector Index", size="sm" + ) graph_data_btn0 = gr.Button("Clear Graph Data", size="sm") vector_import_bt = gr.Button("Import into Vector", variant="primary") @@ -354,9 +381,9 @@ def create_vector_graph_block(): inputs=[input_text, input_schema, info_extract_template], ) - graph_loading_bt.click(import_graph_data, inputs=[out, input_schema], outputs=[out]).then( - update_vid_embedding - ).then( + graph_loading_bt.click( + import_graph_data, inputs=[out, input_schema], outputs=[out] + ).then(update_vid_embedding).then( store_prompt, inputs=[input_text, input_schema, info_extract_template], ) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/__init__.py b/hugegraph-llm/src/hugegraph_llm/flows/__init__.py index 13a83393a..1016680b5 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/__init__.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/__init__.py @@ -14,3 +14,21 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+ +from enum import Enum + + +class FlowName(str, Enum): + RAG_GRAPH_ONLY = "rag_graph_only" + RAG_VECTOR_ONLY = "rag_vector_only" + TEXT2GREMLIN = "text2gremlin" + BUILD_EXAMPLES_INDEX = "build_examples_index" + BUILD_VECTOR_INDEX = "build_vector_index" + GRAPH_EXTRACT = "graph_extract" + IMPORT_GRAPH_DATA = "import_graph_data" + UPDATE_VID_EMBEDDINGS = "update_vid_embeddings" + GET_GRAPH_INDEX_INFO = "get_graph_index_info" + BUILD_SCHEMA = "build_schema" + PROMPT_GENERATE = "prompt_generate" + RAG_RAW = "rag_raw" + RAG_GRAPH_VECTOR = "rag_graph_vector" diff --git a/hugegraph-llm/src/hugegraph_llm/flows/build_example_index.py b/hugegraph-llm/src/hugegraph_llm/flows/build_example_index.py new file mode 100644 index 000000000..d09cc7828 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/flows/build_example_index.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +from typing import List, Dict, Optional + +from PyCGraph import GPipeline + +from hugegraph_llm.flows.common import BaseFlow +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from hugegraph_llm.nodes.index_node.build_gremlin_example_index import ( + BuildGremlinExampleIndexNode, +) +from hugegraph_llm.utils.log import log + + +# pylint: disable=arguments-differ,keyword-arg-before-vararg +class BuildExampleIndexFlow(BaseFlow): + def __init__(self): + pass + + def prepare( + self, + prepared_input: WkFlowInput, + examples: Optional[List[Dict[str, str]]], + **kwargs, + ): + prepared_input.examples = examples + + def build_flow(self, examples=None, **kwargs): + pipeline = GPipeline() + prepared_input = WkFlowInput() + self.prepare(prepared_input, examples=examples) + + pipeline.createGParam(prepared_input, "wkflow_input") + pipeline.createGParam(WkFlowState(), "wkflow_state") + + build_node = BuildGremlinExampleIndexNode() + pipeline.registerGElement(build_node, set(), "build_examples_index") + + return pipeline + + def post_deal(self, pipeline=None, **kwargs): + state_json = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + try: + formatted_schema = json.dumps(state_json, ensure_ascii=False, indent=2) + return formatted_schema + except (TypeError, ValueError) as e: + log.error("Failed to format schema: %s", e) + return str(state_json) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/build_schema.py b/hugegraph-llm/src/hugegraph_llm/flows/build_schema.py index 6bbcb8512..1554e53fe 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/build_schema.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/build_schema.py @@ -13,15 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import json + +from PyCGraph import GPipeline + from hugegraph_llm.flows.common import BaseFlow from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState from hugegraph_llm.nodes.llm_node.schema_build import SchemaBuildNode from hugegraph_llm.utils.log import log -import json -from PyCGraph import GPipeline - +# pylint: disable=arguments-differ,keyword-arg-before-vararg class BuildSchemaFlow(BaseFlow): def __init__(self): pass @@ -32,15 +34,17 @@ def prepare( texts=None, query_examples=None, few_shot_schema=None, + **kwargs, ): prepared_input.texts = texts # Optional fields packed into wk_input for SchemaBuildNode # Keep raw values; node will parse if strings prepared_input.query_examples = query_examples prepared_input.few_shot_schema = few_shot_schema - return - def build_flow(self, texts=None, query_examples=None, few_shot_schema=None): + def build_flow( + self, texts=None, query_examples=None, few_shot_schema=None, **kwargs + ): pipeline = GPipeline() prepared_input = WkFlowInput() self.prepare( @@ -58,7 +62,7 @@ def build_flow(self, texts=None, query_examples=None, few_shot_schema=None): return pipeline - def post_deal(self, pipeline=None): + def post_deal(self, pipeline=None, **kwargs): state_json = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() if "schema" not in state_json: return "" diff --git a/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py b/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py index 9a07b5dba..b57cbfbe3 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/build_vector_index.py @@ -13,28 +13,28 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import json + +from PyCGraph import GPipeline + from hugegraph_llm.flows.common import BaseFlow from hugegraph_llm.nodes.document_node.chunk_split import ChunkSplitNode from hugegraph_llm.nodes.index_node.build_vector_index import BuildVectorIndexNode from hugegraph_llm.state.ai_state import WkFlowInput - -import json -from PyCGraph import GPipeline - from hugegraph_llm.state.ai_state import WkFlowState +# pylint: disable=arguments-differ,keyword-arg-before-vararg class BuildVectorIndexFlow(BaseFlow): def __init__(self): pass - def prepare(self, prepared_input: WkFlowInput, texts): + def prepare(self, prepared_input: WkFlowInput, texts, **kwargs): prepared_input.texts = texts prepared_input.language = "zh" prepared_input.split_type = "paragraph" - return - def build_flow(self, texts): + def build_flow(self, texts, **kwargs): pipeline = GPipeline() # prepare for workflow input prepared_input = WkFlowInput() @@ -50,6 +50,6 @@ def build_flow(self, texts): return pipeline - def post_deal(self, pipeline=None): + def post_deal(self, pipeline=None, **kwargs): res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() return json.dumps(res, ensure_ascii=False, indent=2) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/common.py b/hugegraph-llm/src/hugegraph_llm/flows/common.py index e2348466c..d1301119e 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/common.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/common.py @@ -26,25 +26,22 @@ class BaseFlow(ABC): """ @abstractmethod - def prepare(self, prepared_input: WkFlowInput, *args, **kwargs): + def prepare(self, prepared_input: WkFlowInput, **kwargs): """ Pre-processing interface. """ - pass @abstractmethod - def build_flow(self, *args, **kwargs): + def build_flow(self, **kwargs): """ Interface for building the flow. """ - pass @abstractmethod - def post_deal(self, *args, **kwargs): + def post_deal(self, **kwargs): """ Post-processing interface. 
""" - pass async def post_deal_stream( self, pipeline=None @@ -57,15 +54,11 @@ async def post_deal_stream( if pipeline is None: yield {"error": "No pipeline provided"} return - try: - state_json = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() - log.info(f"{flow_name} post processing success") - stream_flow = state_json.get("stream_generator") - if stream_flow is None: - yield {"error": "No stream_generator found in workflow state"} - return - async for chunk in stream_flow: - yield chunk - except Exception as e: - log.error(f"{flow_name} post processing failed: {e}") - yield {"error": f"Post processing failed: {str(e)}"} + state_json = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + log.info("%s post processing success", flow_name) + stream_flow = state_json.get("stream_generator") + if stream_flow is None: + yield {"error": "No stream_generator found in workflow state"} + return + async for chunk in stream_flow: + yield chunk diff --git a/hugegraph-llm/src/hugegraph_llm/flows/get_graph_index_info.py b/hugegraph-llm/src/hugegraph_llm/flows/get_graph_index_info.py index 7d2735352..86d08bf2d 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/get_graph_index_info.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/get_graph_index_info.py @@ -16,30 +16,32 @@ import json import os +from PyCGraph import GPipeline + from hugegraph_llm.config import huge_settings, llm_settings, resource_path from hugegraph_llm.flows.common import BaseFlow from hugegraph_llm.indices.vector_index import VectorIndex from hugegraph_llm.models.embeddings.init_embedding import model_map from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState from hugegraph_llm.nodes.hugegraph_node.fetch_graph_data import FetchGraphDataNode -from PyCGraph import GPipeline from hugegraph_llm.utils.embedding_utils import ( get_filename_prefix, get_index_folder_name, ) +# pylint: disable=arguments-differ,keyword-arg-before-vararg class GetGraphIndexInfoFlow(BaseFlow): def __init__(self): 
pass - def prepare(self, prepared_input: WkFlowInput, *args, **kwargs): + def prepare(self, prepared_input: WkFlowInput, **kwargs): return - def build_flow(self, *args, **kwargs): + def build_flow(self, **kwargs): pipeline = GPipeline() prepared_input = WkFlowInput() - self.prepare(prepared_input, *args, **kwargs) + self.prepare(prepared_input, **kwargs) pipeline.createGParam(prepared_input, "wkflow_input") pipeline.createGParam(WkFlowState(), "wkflow_state") fetch_node = FetchGraphDataNode() @@ -48,7 +50,9 @@ def build_flow(self, *args, **kwargs): def post_deal(self, pipeline=None): graph_summary_info = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() - folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) + folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) index_dir = str(os.path.join(resource_path, folder_name, "graph_vids")) filename_prefix = get_filename_prefix( llm_settings.embedding_type, @@ -56,7 +60,7 @@ def post_deal(self, pipeline=None): ) try: vector_index = VectorIndex.from_index_file(index_dir, filename_prefix) - except FileNotFoundError: + except (RuntimeError, OSError): return json.dumps(graph_summary_info, ensure_ascii=False, indent=2) graph_summary_info["vid_index"] = { "embed_dim": vector_index.index.d, diff --git a/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py b/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py index 55f53b7ad..f3d166786 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/graph_extract.py @@ -23,25 +23,38 @@ from hugegraph_llm.utils.log import log +# pylint: disable=arguments-differ,keyword-arg-before-vararg class GraphExtractFlow(BaseFlow): def __init__(self): pass - def prepare(self, prepared_input: WkFlowInput, schema, texts, example_prompt, extract_type): + def prepare( + self, + prepared_input: WkFlowInput, + schema, + texts, + example_prompt, + extract_type, + 
language="zh", + **kwargs, + ): # prepare input data prepared_input.texts = texts - prepared_input.language = "zh" + prepared_input.language = language prepared_input.split_type = "document" prepared_input.example_prompt = example_prompt prepared_input.schema = schema prepared_input.extract_type = extract_type - return - def build_flow(self, schema, texts, example_prompt, extract_type): + def build_flow( + self, schema, texts, example_prompt, extract_type, language="zh", **kwargs + ): pipeline = GPipeline() prepared_input = WkFlowInput() # prepare input data - self.prepare(prepared_input, schema, texts, example_prompt, extract_type) + self.prepare( + prepared_input, schema, texts, example_prompt, extract_type, language + ) pipeline.createGParam(prepared_input, "wkflow_input") pipeline.createGParam(WkFlowState(), "wkflow_state") @@ -57,7 +70,7 @@ def build_flow(self, schema, texts, example_prompt, extract_type): return pipeline - def post_deal(self, pipeline=None): + def post_deal(self, pipeline=None, **kwargs): res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() vertices = res.get("vertices", []) edges = res.get("edges", []) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/import_graph_data.py b/hugegraph-llm/src/hugegraph_llm/flows/import_graph_data.py index 0b29b4e64..d0e34ac59 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/import_graph_data.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/import_graph_data.py @@ -24,11 +24,12 @@ from hugegraph_llm.utils.log import log +# pylint: disable=arguments-differ,keyword-arg-before-vararg class ImportGraphDataFlow(BaseFlow): def __init__(self): pass - def prepare(self, prepared_input: WkFlowInput, data, schema): + def prepare(self, prepared_input: WkFlowInput, data, schema, **kwargs): try: data_json = json.loads(data.strip()) if isinstance(data, str) else data except json.JSONDecodeError as e: @@ -43,9 +44,8 @@ def prepare(self, prepared_input: WkFlowInput, data, schema): ) prepared_input.data_json = 
data_json prepared_input.schema = schema - return - def build_flow(self, data, schema): + def build_flow(self, data, schema, **kwargs): pipeline = GPipeline() prepared_input = WkFlowInput() # prepare input data @@ -61,7 +61,7 @@ def build_flow(self, data, schema): return pipeline - def post_deal(self, pipeline=None): + def post_deal(self, pipeline=None, **kwargs): res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() gr.Info("Import graph data successfully!") return json.dumps(res, ensure_ascii=False, indent=2) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/prompt_generate.py b/hugegraph-llm/src/hugegraph_llm/flows/prompt_generate.py index b4a7bf329..16618e13a 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/prompt_generate.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/prompt_generate.py @@ -13,29 +13,30 @@ # See the License for the specific language governing permissions and # limitations under the License. +from PyCGraph import GPipeline + from hugegraph_llm.flows.common import BaseFlow from hugegraph_llm.nodes.llm_node.prompt_generate import PromptGenerateNode from hugegraph_llm.state.ai_state import WkFlowInput - -from PyCGraph import GPipeline - from hugegraph_llm.state.ai_state import WkFlowState +# pylint: disable=arguments-differ,keyword-arg-before-vararg class PromptGenerateFlow(BaseFlow): def __init__(self): pass - def prepare(self, prepared_input: WkFlowInput, source_text, scenario, example_name): + def prepare( + self, prepared_input: WkFlowInput, source_text, scenario, example_name, **kwargs + ): """ Prepare input data for PromptGenerate workflow """ prepared_input.source_text = source_text prepared_input.scenario = scenario prepared_input.example_name = example_name - return - def build_flow(self, source_text, scenario, example_name): + def build_flow(self, source_text, scenario, example_name, **kwargs): """ Build the PromptGenerate workflow """ @@ -53,9 +54,11 @@ def build_flow(self, source_text, scenario, example_name): return 
pipeline - def post_deal(self, pipeline=None): + def post_deal(self, pipeline=None, **kwargs): """ Process the execution result of PromptGenerate workflow """ res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() - return res.get("generated_extract_prompt", "Generation failed. Please check the logs.") + return res.get( + "generated_extract_prompt", "Generation failed. Please check the logs." + ) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_only.py b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_only.py index 5feb3d471..3029b6259 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_only.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_only.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json -from typing import Optional, Literal +from typing import Optional, Literal, cast -from PyCGraph import GPipeline +from PyCGraph import GPipeline, GRegion, GCondition from hugegraph_llm.flows.common import BaseFlow from hugegraph_llm.nodes.llm_node.keyword_extract_node import KeywordExtractNode @@ -31,6 +30,23 @@ from hugegraph_llm.utils.log import log +class GraphRecallCondition(GCondition): + def choose(self): + prepared_input: WkFlowInput = cast( + WkFlowInput, self.getGParamWithNoEmpty("wkflow_input") + ) + return 0 if prepared_input.is_graph_rag_recall else 1 + + +class VectorOnlyCondition(GCondition): + def choose(self): + prepared_input: WkFlowInput = cast( + WkFlowInput, self.getGParamWithNoEmpty("wkflow_input") + ) + return 0 if prepared_input.is_vector_only else 1 + + +# pylint: disable=arguments-differ,keyword-arg-before-vararg class RAGGraphOnlyFlow(BaseFlow): """ Workflow for graph-only answering (graph_only_answer) @@ -40,13 +56,12 @@ def prepare( self, prepared_input: WkFlowInput, query: str, - vector_search: bool = None, - graph_search: bool = None, - raw_answer: bool = None, - vector_only_answer: bool = None, - 
graph_only_answer: bool = None, - graph_vector_answer: bool = None, - graph_ratio: float = 0.5, + vector_search: bool = False, + graph_search: bool = True, + raw_answer: bool = False, + vector_only_answer: bool = False, + graph_only_answer: bool = True, + graph_vector_answer: bool = False, rerank_method: Literal["bleu", "reranker"] = "bleu", near_neighbor_first: bool = False, custom_related_information: str = "", @@ -54,11 +69,13 @@ def prepare( keywords_extract_prompt: Optional[str] = None, gremlin_tmpl_num: Optional[int] = -1, gremlin_prompt: Optional[str] = None, - max_graph_items: int = None, - topk_return_results: int = None, - vector_dis_threshold: float = None, - topk_per_keyword: int = None, - **_: dict, + max_graph_items: Optional[int] = None, + topk_return_results: Optional[int] = None, + vector_dis_threshold: Optional[float] = None, + topk_per_keyword: Optional[int] = None, + is_graph_rag_recall: bool = False, + is_vector_only: bool = False, + **kwargs, ): prepared_input.query = query prepared_input.vector_search = vector_search @@ -90,13 +107,15 @@ def prepare( ) prepared_input.schema = huge_settings.graph_name + prepared_input.is_graph_rag_recall = is_graph_rag_recall + prepared_input.is_vector_only = is_vector_only prepared_input.data_json = { "query": query, "vector_search": vector_search, "graph_search": graph_search, "max_graph_items": max_graph_items or huge_settings.max_graph_items, + "is_graph_rag_recall": is_graph_rag_recall, } - return def build_flow(self, **kwargs): pipeline = GPipeline() @@ -106,48 +125,49 @@ def build_flow(self, **kwargs): pipeline.createGParam(WkFlowState(), "wkflow_state") # Create nodes and register them with registerGElement - only_keyword_extract_node = KeywordExtractNode() - only_semantic_id_query_node = SemanticIdQueryNode() + only_keyword_extract_node = KeywordExtractNode("only_keyword") + only_semantic_id_query_node = SemanticIdQueryNode( + {only_keyword_extract_node}, "only_semantic" + ) + vector_region: GRegion = 
GRegion( + [only_keyword_extract_node, only_semantic_id_query_node] + ) + only_schema_node = SchemaNode() - only_graph_query_node = GraphQueryNode() - merge_rerank_node = MergeRerankNode() + schema_node = VectorOnlyCondition([GRegion(), only_schema_node]) + only_graph_query_node = GraphQueryNode("only_graph") + merge_rerank_node = MergeRerankNode({only_graph_query_node}, "merge_rerank") + graph_region: GRegion = GRegion([only_graph_query_node, merge_rerank_node]) + graph_condition_region = VectorOnlyCondition([GRegion(), graph_region]) + answer_synthesize_node = AnswerSynthesizeNode() + answer_node = GraphRecallCondition([GRegion(), answer_synthesize_node]) - pipeline.registerGElement(only_keyword_extract_node, set(), "only_keyword") + pipeline.registerGElement(vector_region, set(), "vector_fetch") + pipeline.registerGElement(schema_node, set(), "schema_condition") pipeline.registerGElement( - only_semantic_id_query_node, {only_keyword_extract_node}, "only_semantic" + graph_condition_region, + {schema_node, vector_region}, + "graph_condition", ) - pipeline.registerGElement(only_schema_node, set(), "only_schema") pipeline.registerGElement( - only_graph_query_node, - {only_schema_node, only_semantic_id_query_node}, - "only_graph", + answer_node, {graph_condition_region}, "answer_condition" ) - pipeline.registerGElement( - merge_rerank_node, {only_graph_query_node}, "merge_one" - ) - pipeline.registerGElement(answer_synthesize_node, {merge_rerank_node}, "graph") log.info("RAGGraphOnlyFlow pipeline built successfully") return pipeline - def post_deal(self, pipeline=None): + def post_deal(self, pipeline=None, **kwargs): if pipeline is None: - return json.dumps( - {"error": "No pipeline provided"}, ensure_ascii=False, indent=2 - ) - try: - res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() - log.info("RAGGraphOnlyFlow post processing success") - return { + return {"error": "No pipeline provided"} + res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() 
+ log.info("RAGGraphOnlyFlow post processing success") + return ( + { "raw_answer": res.get("raw_answer", ""), "vector_only_answer": res.get("vector_only_answer", ""), "graph_only_answer": res.get("graph_only_answer", ""), "graph_vector_answer": res.get("graph_vector_answer", ""), } - except Exception as e: - log.error(f"RAGGraphOnlyFlow post processing failed: {e}") - return json.dumps( - {"error": f"Post processing failed: {str(e)}"}, - ensure_ascii=False, - indent=2, - ) + if not res.get("is_graph_rag_recall", False) + else res + ) diff --git a/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_vector.py b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_vector.py index 2f4a2bfa2..96c4ab858 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_vector.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_graph_vector.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json from typing import Optional, Literal @@ -32,6 +31,7 @@ from hugegraph_llm.utils.log import log +# pylint: disable=arguments-differ,keyword-arg-before-vararg class RAGGraphVectorFlow(BaseFlow): """ Workflow for graph + vector hybrid answering (graph_vector_answer) @@ -41,12 +41,12 @@ def prepare( self, prepared_input: WkFlowInput, query: str, - vector_search: bool = None, - graph_search: bool = None, - raw_answer: bool = None, - vector_only_answer: bool = None, - graph_only_answer: bool = None, - graph_vector_answer: bool = None, + vector_search: bool = True, + graph_search: bool = True, + raw_answer: bool = False, + vector_only_answer: bool = False, + graph_only_answer: bool = False, + graph_vector_answer: bool = True, graph_ratio: float = 0.5, rerank_method: Literal["bleu", "reranker"] = "bleu", near_neighbor_first: bool = False, @@ -55,11 +55,11 @@ def prepare( keywords_extract_prompt: Optional[str] = None, gremlin_tmpl_num: Optional[int] = -1, gremlin_prompt: Optional[str] = None, - 
max_graph_items: int = None, - topk_return_results: int = None, - vector_dis_threshold: float = None, - topk_per_keyword: int = None, - **_: dict, + max_graph_items: Optional[int] = None, + topk_return_results: Optional[int] = None, + vector_dis_threshold: Optional[float] = None, + topk_per_keyword: Optional[int] = None, + **kwargs, ): prepared_input.query = query prepared_input.vector_search = vector_search @@ -98,7 +98,6 @@ def prepare( "graph_search": graph_search, "max_graph_items": max_graph_items or huge_settings.max_graph_items, } - return def build_flow(self, **kwargs): pipeline = GPipeline() @@ -135,24 +134,14 @@ def build_flow(self, **kwargs): log.info("RAGGraphVectorFlow pipeline built successfully") return pipeline - def post_deal(self, pipeline=None): + def post_deal(self, pipeline=None, **kwargs): if pipeline is None: - return json.dumps( - {"error": "No pipeline provided"}, ensure_ascii=False, indent=2 - ) - try: - res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() - log.info("RAGGraphVectorFlow post processing success") - return { - "raw_answer": res.get("raw_answer", ""), - "vector_only_answer": res.get("vector_only_answer", ""), - "graph_only_answer": res.get("graph_only_answer", ""), - "graph_vector_answer": res.get("graph_vector_answer", ""), - } - except Exception as e: - log.error(f"RAGGraphVectorFlow post processing failed: {e}") - return json.dumps( - {"error": f"Post processing failed: {str(e)}"}, - ensure_ascii=False, - indent=2, - ) + return {"error": "No pipeline provided"} + res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + log.info("RAGGraphVectorFlow post processing success") + return { + "raw_answer": res.get("raw_answer", ""), + "vector_only_answer": res.get("vector_only_answer", ""), + "graph_only_answer": res.get("graph_only_answer", ""), + "graph_vector_answer": res.get("graph_vector_answer", ""), + } diff --git a/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_raw.py 
b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_raw.py index f62e574bb..ede8f98e1 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_raw.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_raw.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json from typing import Optional @@ -26,6 +25,7 @@ from hugegraph_llm.utils.log import log +# pylint: disable=arguments-differ,keyword-arg-before-vararg class RAGRawFlow(BaseFlow): """ Workflow for basic LLM answering only (raw_answer) @@ -35,16 +35,16 @@ def prepare( self, prepared_input: WkFlowInput, query: str, - vector_search: bool = None, - graph_search: bool = None, - raw_answer: bool = None, - vector_only_answer: bool = None, - graph_only_answer: bool = None, - graph_vector_answer: bool = None, + vector_search: bool = False, + graph_search: bool = False, + raw_answer: bool = True, + vector_only_answer: bool = False, + graph_only_answer: bool = False, + graph_vector_answer: bool = False, custom_related_information: str = "", answer_prompt: Optional[str] = None, - max_graph_items: int = None, - **_: dict, + max_graph_items: Optional[int] = None, + **kwargs, ): prepared_input.query = query prepared_input.raw_answer = raw_answer @@ -61,7 +61,6 @@ def prepare( "graph_search": graph_search, "max_graph_items": max_graph_items or huge_settings.max_graph_items, } - return def build_flow(self, **kwargs): pipeline = GPipeline() @@ -76,24 +75,14 @@ def build_flow(self, **kwargs): log.info("RAGRawFlow pipeline built successfully") return pipeline - def post_deal(self, pipeline=None): + def post_deal(self, pipeline=None, **kwargs): if pipeline is None: - return json.dumps( - {"error": "No pipeline provided"}, ensure_ascii=False, indent=2 - ) - try: - res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() - log.info("RAGRawFlow post processing success") - return { - "raw_answer": res.get("raw_answer", ""), - "vector_only_answer": 
res.get("vector_only_answer", ""), - "graph_only_answer": res.get("graph_only_answer", ""), - "graph_vector_answer": res.get("graph_vector_answer", ""), - } - except Exception as e: - log.error(f"RAGRawFlow post processing failed: {e}") - return json.dumps( - {"error": f"Post processing failed: {str(e)}"}, - ensure_ascii=False, - indent=2, - ) + return {"error": "No pipeline provided"} + res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + log.info("RAGRawFlow post processing success") + return { + "raw_answer": res.get("raw_answer", ""), + "vector_only_answer": res.get("vector_only_answer", ""), + "graph_only_answer": res.get("graph_only_answer", ""), + "graph_vector_answer": res.get("graph_vector_answer", ""), + } diff --git a/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_vector_only.py b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_vector_only.py index c727eacce..150e3162a 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_vector_only.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/rag_flow_vector_only.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import json from typing import Optional, Literal @@ -28,6 +27,7 @@ from hugegraph_llm.utils.log import log +# pylint: disable=arguments-differ,keyword-arg-before-vararg class RAGVectorOnlyFlow(BaseFlow): """ Workflow for vector-only answering (vector_only_answer) @@ -37,20 +37,20 @@ def prepare( self, prepared_input: WkFlowInput, query: str, - vector_search: bool = None, - graph_search: bool = None, - raw_answer: bool = None, - vector_only_answer: bool = None, - graph_only_answer: bool = None, - graph_vector_answer: bool = None, + vector_search: bool = True, + graph_search: bool = False, + raw_answer: bool = False, + vector_only_answer: bool = True, + graph_only_answer: bool = False, + graph_vector_answer: bool = False, rerank_method: Literal["bleu", "reranker"] = "bleu", near_neighbor_first: bool = False, custom_related_information: str = "", answer_prompt: Optional[str] = None, - max_graph_items: int = None, - topk_return_results: int = None, - vector_dis_threshold: float = None, - **_: dict, + max_graph_items: Optional[int] = None, + topk_return_results: Optional[int] = None, + vector_dis_threshold: Optional[float] = None, + **kwargs, ): prepared_input.query = query prepared_input.vector_search = vector_search @@ -77,7 +77,6 @@ def prepare( "graph_search": graph_search, "max_graph_items": max_graph_items or huge_settings.max_graph_items, } - return def build_flow(self, **kwargs): pipeline = GPipeline() @@ -100,24 +99,14 @@ def build_flow(self, **kwargs): log.info("RAGVectorOnlyFlow pipeline built successfully") return pipeline - def post_deal(self, pipeline=None): + def post_deal(self, pipeline=None, **kwargs): if pipeline is None: - return json.dumps( - {"error": "No pipeline provided"}, ensure_ascii=False, indent=2 - ) - try: - res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() - log.info("RAGVectorOnlyFlow post processing success") - return { - "raw_answer": res.get("raw_answer", ""), - "vector_only_answer": res.get("vector_only_answer", ""), - 
"graph_only_answer": res.get("graph_only_answer", ""), - "graph_vector_answer": res.get("graph_vector_answer", ""), - } - except Exception as e: - log.error(f"RAGVectorOnlyFlow post processing failed: {e}") - return json.dumps( - {"error": f"Post processing failed: {str(e)}"}, - ensure_ascii=False, - indent=2, - ) + return {"error": "No pipeline provided"} + res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() + log.info("RAGVectorOnlyFlow post processing success") + return { + "raw_answer": res.get("raw_answer", ""), + "vector_only_answer": res.get("vector_only_answer", ""), + "graph_only_answer": res.get("graph_only_answer", ""), + "graph_vector_answer": res.get("graph_vector_answer", ""), + } diff --git a/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py b/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py index 5afa1bf8e..bdf59d84e 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/scheduler.py @@ -16,11 +16,13 @@ import threading from typing import Dict, Any from PyCGraph import GPipeline, GPipelineManager +from hugegraph_llm.flows import FlowName from hugegraph_llm.flows.build_vector_index import BuildVectorIndexFlow from hugegraph_llm.flows.common import BaseFlow +from hugegraph_llm.flows.build_example_index import BuildExampleIndexFlow from hugegraph_llm.flows.graph_extract import GraphExtractFlow from hugegraph_llm.flows.import_graph_data import ImportGraphDataFlow -from hugegraph_llm.flows.update_vid_embeddings import UpdateVidEmbeddingsFlows +from hugegraph_llm.flows.update_vid_embeddings import UpdateVidEmbeddingsFlow from hugegraph_llm.flows.get_graph_index_info import GetGraphIndexInfoFlow from hugegraph_llm.flows.build_schema import BuildSchemaFlow from hugegraph_llm.flows.prompt_generate import PromptGenerateFlow @@ -34,72 +36,76 @@ class Scheduler: - pipeline_pool: Dict[str, Any] = None + pipeline_pool: Dict[str, Any] max_pipeline: int def __init__(self, max_pipeline: int = 10): 
self.pipeline_pool = {} # pipeline_pool act as a manager of GPipelineManager which used for pipeline management - self.pipeline_pool["build_vector_index"] = { + self.pipeline_pool[FlowName.BUILD_VECTOR_INDEX] = { "manager": GPipelineManager(), "flow": BuildVectorIndexFlow(), } - self.pipeline_pool["graph_extract"] = { + self.pipeline_pool[FlowName.GRAPH_EXTRACT] = { "manager": GPipelineManager(), "flow": GraphExtractFlow(), } - self.pipeline_pool["import_graph_data"] = { + self.pipeline_pool[FlowName.IMPORT_GRAPH_DATA] = { "manager": GPipelineManager(), "flow": ImportGraphDataFlow(), } - self.pipeline_pool["update_vid_embeddings"] = { + self.pipeline_pool[FlowName.UPDATE_VID_EMBEDDINGS] = { "manager": GPipelineManager(), - "flow": UpdateVidEmbeddingsFlows(), + "flow": UpdateVidEmbeddingsFlow(), } - self.pipeline_pool["get_graph_index_info"] = { + self.pipeline_pool[FlowName.GET_GRAPH_INDEX_INFO] = { "manager": GPipelineManager(), "flow": GetGraphIndexInfoFlow(), } - self.pipeline_pool["build_schema"] = { + self.pipeline_pool[FlowName.BUILD_SCHEMA] = { "manager": GPipelineManager(), "flow": BuildSchemaFlow(), } - self.pipeline_pool["prompt_generate"] = { + self.pipeline_pool[FlowName.PROMPT_GENERATE] = { "manager": GPipelineManager(), "flow": PromptGenerateFlow(), } - self.pipeline_pool["text2gremlin"] = { + self.pipeline_pool[FlowName.TEXT2GREMLIN] = { "manager": GPipelineManager(), "flow": Text2GremlinFlow(), } # New split rag pipelines - self.pipeline_pool["rag_raw"] = { + self.pipeline_pool[FlowName.RAG_RAW] = { "manager": GPipelineManager(), "flow": RAGRawFlow(), } - self.pipeline_pool["rag_vector_only"] = { + self.pipeline_pool[FlowName.RAG_VECTOR_ONLY] = { "manager": GPipelineManager(), "flow": RAGVectorOnlyFlow(), } - self.pipeline_pool["rag_graph_only"] = { + self.pipeline_pool[FlowName.RAG_GRAPH_ONLY] = { "manager": GPipelineManager(), "flow": RAGGraphOnlyFlow(), } - self.pipeline_pool["rag_graph_vector"] = { + self.pipeline_pool[FlowName.RAG_GRAPH_VECTOR] 
= { "manager": GPipelineManager(), "flow": RAGGraphVectorFlow(), } + self.pipeline_pool[FlowName.BUILD_EXAMPLES_INDEX] = { + "manager": GPipelineManager(), + "flow": BuildExampleIndexFlow(), + } self.max_pipeline = max_pipeline # TODO: Implement Agentic Workflow def agentic_flow(self): pass - def schedule_flow(self, flow: str, *args, **kwargs): - if flow not in self.pipeline_pool: - raise ValueError(f"Unsupported workflow {flow}") - manager: GPipelineManager = self.pipeline_pool[flow]["manager"] - flow: BaseFlow = self.pipeline_pool[flow]["flow"] + def schedule_flow(self, flow_name: str, *args, **kwargs): + if flow_name not in self.pipeline_pool: + raise ValueError(f"Unsupported workflow {flow_name}") + manager: GPipelineManager = self.pipeline_pool[flow_name]["manager"] + flow: BaseFlow = self.pipeline_pool[flow_name]["flow"] pipeline: GPipeline = manager.fetch() if pipeline is None: # call coresponding flow_func to create new workflow @@ -111,13 +117,14 @@ def schedule_flow(self, flow: str, *args, **kwargs): raise RuntimeError(error_msg) status = pipeline.run() if status.isErr(): + manager.add(pipeline) error_msg = f"Error in flow execution: {status.getInfo()}" log.error(error_msg) raise RuntimeError(error_msg) res = flow.post_deal(pipeline) manager.add(pipeline) return res - else: + try: # fetch pipeline & prepare input for flow prepared_input = pipeline.getGParamWithNoEmpty("wkflow_input") flow.prepare(prepared_input, *args, **kwargs) @@ -127,49 +134,46 @@ def schedule_flow(self, flow: str, *args, **kwargs): log.error(error_msg) raise RuntimeError(error_msg) res = flow.post_deal(pipeline) + finally: manager.release(pipeline) - return res + return res - async def schedule_stream_flow(self, flow: str, *args, **kwargs): - if flow not in self.pipeline_pool: - raise ValueError(f"Unsupported workflow {flow}") - manager: GPipelineManager = self.pipeline_pool[flow]["manager"] - flow: BaseFlow = self.pipeline_pool[flow]["flow"] + async def schedule_stream_flow(self, 
flow_name: str, *args, **kwargs): + if flow_name not in self.pipeline_pool: + raise ValueError(f"Unsupported workflow {flow_name}") + manager: GPipelineManager = self.pipeline_pool[flow_name]["manager"] + flow: BaseFlow = self.pipeline_pool[flow_name]["flow"] pipeline: GPipeline = manager.fetch() if pipeline is None: # call coresponding flow_func to create new workflow pipeline = flow.build_flow(*args, **kwargs) - try: - pipeline.getGParamWithNoEmpty("wkflow_input").stream = True - status = pipeline.init() - if status.isErr(): - error_msg = f"Error in flow init: {status.getInfo()}" - log.error(error_msg) - raise RuntimeError(error_msg) - status = pipeline.run() - if status.isErr(): - error_msg = f"Error in flow execution: {status.getInfo()}" - log.error(error_msg) - raise RuntimeError(error_msg) - async for res in flow.post_deal_stream(pipeline): - yield res - finally: + pipeline.getGParamWithNoEmpty("wkflow_input").stream = True + status = pipeline.init() + if status.isErr(): + error_msg = f"Error in flow init: {status.getInfo()}" + log.error(error_msg) + raise RuntimeError(error_msg) + status = pipeline.run() + if status.isErr(): manager.add(pipeline) - else: - try: - # fetch pipeline & prepare input for flow - prepared_input: WkFlowInput = pipeline.getGParamWithNoEmpty( - "wkflow_input" - ) - prepared_input.stream = True - flow.prepare(prepared_input, *args, **kwargs) - status = pipeline.run() - if status.isErr(): - raise RuntimeError(f"Error in flow execution {status.getInfo()}") - async for res in flow.post_deal_stream(pipeline): - yield res - finally: - manager.release(pipeline) + error_msg = f"Error in flow execution: {status.getInfo()}" + log.error(error_msg) + raise RuntimeError(error_msg) + async for res in flow.post_deal_stream(pipeline): + yield res + manager.add(pipeline) + try: + # fetch pipeline & prepare input for flow + prepared_input: WkFlowInput = pipeline.getGParamWithNoEmpty("wkflow_input") + prepared_input.stream = True + 
flow.prepare(prepared_input, *args, **kwargs) + status = pipeline.run() + if status.isErr(): + raise RuntimeError(f"Error in flow execution {status.getInfo()}") + async for res in flow.post_deal_stream(pipeline): + yield res + finally: + manager.release(pipeline) class SchedulerSingleton: diff --git a/hugegraph-llm/src/hugegraph_llm/flows/text2gremlin.py b/hugegraph-llm/src/hugegraph_llm/flows/text2gremlin.py index e9ba4276c..1ae5662cb 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/text2gremlin.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/text2gremlin.py @@ -13,18 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, List, Optional + from PyCGraph import GPipeline from hugegraph_llm.flows.common import BaseFlow from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState from hugegraph_llm.nodes.hugegraph_node.schema import SchemaNode -from hugegraph_llm.nodes.index_node.gremlin_example_index_query import GremlinExampleIndexQueryNode +from hugegraph_llm.nodes.index_node.gremlin_example_index_query import ( + GremlinExampleIndexQueryNode, +) from hugegraph_llm.nodes.llm_node.text2gremlin import Text2GremlinNode from hugegraph_llm.nodes.hugegraph_node.gremlin_execute import GremlinExecuteNode -from typing import Any, Dict, List, Optional - +# pylint: disable=arguments-differ,keyword-arg-before-vararg class Text2GremlinFlow(BaseFlow): def __init__(self): pass @@ -37,6 +40,7 @@ def prepare( schema_input: str, gremlin_prompt_input: Optional[str], requested_outputs: Optional[List[str]], + **kwargs, ): # sanitize example_num to [0,10], fallback to 2 if invalid if not isinstance(example_num, int): @@ -63,7 +67,6 @@ def prepare( prepared_input.schema = schema_input prepared_input.gremlin_prompt = gremlin_prompt_input prepared_input.requested_outputs = req - return def build_flow( self, @@ -72,6 +75,7 @@ def build_flow( schema_input: str, gremlin_prompt_input: Optional[str] = 
None, requested_outputs: Optional[List[str]] = None, + **kwargs, ): pipeline = GPipeline() @@ -100,7 +104,7 @@ def build_flow( return pipeline - def post_deal(self, pipeline=None) -> Dict[str, Any]: + def post_deal(self, pipeline=None, **kwargs) -> Dict[str, Any]: state = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() # 始终返回 5 个标准键,避免前端因过滤异常看不到字段 return { diff --git a/hugegraph-llm/src/hugegraph_llm/flows/update_vid_embeddings.py b/hugegraph-llm/src/hugegraph_llm/flows/update_vid_embeddings.py index b3f0d9923..216f35618 100644 --- a/hugegraph-llm/src/hugegraph_llm/flows/update_vid_embeddings.py +++ b/hugegraph-llm/src/hugegraph_llm/flows/update_vid_embeddings.py @@ -13,18 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from PyCGraph import CStatus, GPipeline -from hugegraph_llm.flows.common import BaseFlow, WkFlowInput +from PyCGraph import GPipeline + +from hugegraph_llm.flows.common import BaseFlow from hugegraph_llm.nodes.hugegraph_node.fetch_graph_data import FetchGraphDataNode from hugegraph_llm.nodes.index_node.build_semantic_index import BuildSemanticIndexNode -from hugegraph_llm.state.ai_state import WkFlowState +from hugegraph_llm.state.ai_state import WkFlowState, WkFlowInput -class UpdateVidEmbeddingsFlows(BaseFlow): - def prepare(self, prepared_input: WkFlowInput): - return CStatus() +# pylint: disable=arguments-differ,keyword-arg-before-vararg +class UpdateVidEmbeddingsFlow(BaseFlow): + def prepare(self, prepared_input: WkFlowInput, **kwargs): + pass - def build_flow(self): + def build_flow(self, **kwargs): pipeline = GPipeline() prepared_input = WkFlowInput() # prepare input data @@ -40,7 +42,7 @@ def build_flow(self): return pipeline - def post_deal(self, pipeline): + def post_deal(self, pipeline, **kwargs): res = pipeline.getGParamWithNoEmpty("wkflow_state").to_json() removed_num = res.get("removed_vid_vector_num", 0) added_num = res.get("added_vid_vector_num", 0) diff --git 
a/hugegraph-llm/src/hugegraph_llm/flows/utils.py b/hugegraph-llm/src/hugegraph_llm/flows/utils.py deleted file mode 100644 index b4ba05c84..000000000 --- a/hugegraph-llm/src/hugegraph_llm/flows/utils.py +++ /dev/null @@ -1,34 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json - -from hugegraph_llm.state.ai_state import WkFlowInput -from hugegraph_llm.utils.log import log - - -def prepare_schema(prepared_input: WkFlowInput, schema): - schema = schema.strip() - if schema.startswith("{"): - try: - schema = json.loads(schema) - prepared_input.schema = schema - except json.JSONDecodeError as exc: - log.error("Invalid JSON format in schema. 
Please check it again.") - raise ValueError("Invalid JSON format in schema.") from exc - else: - log.info("Get schema '%s' from graphdb.", schema) - prepared_input.graph_name = schema - return diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py index 3ad50b3ec..de04dff87 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py +++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py @@ -29,27 +29,27 @@ } -def get_embedding(llm_settings: LLMConfig): - if llm_settings.embedding_type == "openai": +def get_embedding(llm_configs: LLMConfig): + if llm_configs.embedding_type == "openai": return OpenAIEmbedding( - model_name=llm_settings.openai_embedding_model, - api_key=llm_settings.openai_embedding_api_key, - api_base=llm_settings.openai_embedding_api_base, + model_name=llm_configs.openai_embedding_model, + api_key=llm_configs.openai_embedding_api_key, + api_base=llm_configs.openai_embedding_api_base, ) - if llm_settings.embedding_type == "ollama/local": + if llm_configs.embedding_type == "ollama/local": return OllamaEmbedding( - model_name=llm_settings.ollama_embedding_model, - host=llm_settings.ollama_embedding_host, - port=llm_settings.ollama_embedding_port, + model_name=llm_configs.ollama_embedding_model, + host=llm_configs.ollama_embedding_host, + port=llm_configs.ollama_embedding_port, ) - if llm_settings.embedding_type == "litellm": + if llm_configs.embedding_type == "litellm": return LiteLLMEmbedding( - model_name=llm_settings.litellm_embedding_model, - api_key=llm_settings.litellm_embedding_api_key, - api_base=llm_settings.litellm_embedding_api_base, + model_name=llm_configs.litellm_embedding_model, + api_key=llm_configs.litellm_embedding_api_key, + api_base=llm_configs.litellm_embedding_api_base, ) - raise Exception("embedding type is not supported !") + raise ValueError("embedding type is not supported !") class 
Embeddings: diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py b/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py index 9121fca09..a13641db0 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py @@ -22,74 +22,74 @@ from hugegraph_llm.config import llm_settings -def get_chat_llm(llm_settings: LLMConfig): - if llm_settings.chat_llm_type == "openai": +def get_chat_llm(llm_configs: LLMConfig): + if llm_configs.chat_llm_type == "openai": return OpenAIClient( - api_key=llm_settings.openai_chat_api_key, - api_base=llm_settings.openai_chat_api_base, - model_name=llm_settings.openai_chat_language_model, - max_tokens=llm_settings.openai_chat_tokens, + api_key=llm_configs.openai_chat_api_key, + api_base=llm_configs.openai_chat_api_base, + model_name=llm_configs.openai_chat_language_model, + max_tokens=llm_configs.openai_chat_tokens, ) - if llm_settings.chat_llm_type == "ollama/local": + if llm_configs.chat_llm_type == "ollama/local": return OllamaClient( - model=llm_settings.ollama_chat_language_model, - host=llm_settings.ollama_chat_host, - port=llm_settings.ollama_chat_port, + model=llm_configs.ollama_chat_language_model, + host=llm_configs.ollama_chat_host, + port=llm_configs.ollama_chat_port, ) - if llm_settings.chat_llm_type == "litellm": + if llm_configs.chat_llm_type == "litellm": return LiteLLMClient( - api_key=llm_settings.litellm_chat_api_key, - api_base=llm_settings.litellm_chat_api_base, - model_name=llm_settings.litellm_chat_language_model, - max_tokens=llm_settings.litellm_chat_tokens, + api_key=llm_configs.litellm_chat_api_key, + api_base=llm_configs.litellm_chat_api_base, + model_name=llm_configs.litellm_chat_language_model, + max_tokens=llm_configs.litellm_chat_tokens, ) raise Exception("chat llm type is not supported !") -def get_extract_llm(llm_settings: LLMConfig): - if llm_settings.extract_llm_type == "openai": +def get_extract_llm(llm_configs: 
LLMConfig): + if llm_configs.extract_llm_type == "openai": return OpenAIClient( - api_key=llm_settings.openai_extract_api_key, - api_base=llm_settings.openai_extract_api_base, - model_name=llm_settings.openai_extract_language_model, - max_tokens=llm_settings.openai_extract_tokens, + api_key=llm_configs.openai_extract_api_key, + api_base=llm_configs.openai_extract_api_base, + model_name=llm_configs.openai_extract_language_model, + max_tokens=llm_configs.openai_extract_tokens, ) - if llm_settings.extract_llm_type == "ollama/local": + if llm_configs.extract_llm_type == "ollama/local": return OllamaClient( - model=llm_settings.ollama_extract_language_model, - host=llm_settings.ollama_extract_host, - port=llm_settings.ollama_extract_port, + model=llm_configs.ollama_extract_language_model, + host=llm_configs.ollama_extract_host, + port=llm_configs.ollama_extract_port, ) - if llm_settings.extract_llm_type == "litellm": + if llm_configs.extract_llm_type == "litellm": return LiteLLMClient( - api_key=llm_settings.litellm_extract_api_key, - api_base=llm_settings.litellm_extract_api_base, - model_name=llm_settings.litellm_extract_language_model, - max_tokens=llm_settings.litellm_extract_tokens, + api_key=llm_configs.litellm_extract_api_key, + api_base=llm_configs.litellm_extract_api_base, + model_name=llm_configs.litellm_extract_language_model, + max_tokens=llm_configs.litellm_extract_tokens, ) raise Exception("extract llm type is not supported !") -def get_text2gql_llm(llm_settings: LLMConfig): - if llm_settings.text2gql_llm_type == "openai": +def get_text2gql_llm(llm_configs: LLMConfig): + if llm_configs.text2gql_llm_type == "openai": return OpenAIClient( - api_key=llm_settings.openai_text2gql_api_key, - api_base=llm_settings.openai_text2gql_api_base, - model_name=llm_settings.openai_text2gql_language_model, - max_tokens=llm_settings.openai_text2gql_tokens, + api_key=llm_configs.openai_text2gql_api_key, + api_base=llm_configs.openai_text2gql_api_base, + 
model_name=llm_configs.openai_text2gql_language_model, + max_tokens=llm_configs.openai_text2gql_tokens, ) - if llm_settings.text2gql_llm_type == "ollama/local": + if llm_configs.text2gql_llm_type == "ollama/local": return OllamaClient( - model=llm_settings.ollama_text2gql_language_model, - host=llm_settings.ollama_text2gql_host, - port=llm_settings.ollama_text2gql_port, + model=llm_configs.ollama_text2gql_language_model, + host=llm_configs.ollama_text2gql_host, + port=llm_configs.ollama_text2gql_port, ) - if llm_settings.text2gql_llm_type == "litellm": + if llm_configs.text2gql_llm_type == "litellm": return LiteLLMClient( - api_key=llm_settings.litellm_text2gql_api_key, - api_base=llm_settings.litellm_text2gql_api_base, - model_name=llm_settings.litellm_text2gql_language_model, - max_tokens=llm_settings.litellm_text2gql_tokens, + api_key=llm_configs.litellm_text2gql_api_key, + api_base=llm_configs.litellm_text2gql_api_base, + model_name=llm_configs.litellm_text2gql_language_model, + max_tokens=llm_configs.litellm_text2gql_tokens, ) raise Exception("text2gql llm type is not supported !") @@ -173,4 +173,8 @@ def get_text2gql_llm(self): if __name__ == "__main__": client = LLMs().get_chat_llm() print(client.generate(prompt="What is the capital of China?")) - print(client.generate(messages=[{"role": "user", "content": "What is the capital of China?"}])) + print( + client.generate( + messages=[{"role": "user", "content": "What is the capital of China?"}] + ) + ) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py index f90167305..d7e53d4b8 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/base_node.py @@ -13,14 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Dict, Optional from PyCGraph import GNode, CStatus from hugegraph_llm.nodes.util import init_context from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState +from hugegraph_llm.utils.log import log class BaseNode(GNode): - context: WkFlowState = None - wk_input: WkFlowInput = None + """ + Base class for workflow nodes, providing context management and operation scheduling. + + All custom nodes should inherit from this class and implement the operator_schedule method. + + Attributes: + context: Shared workflow state + wk_input: Workflow input parameters + """ + + context: Optional[WkFlowState] = None + wk_input: Optional[WkFlowInput] = None def init(self): return init_context(self) @@ -30,6 +42,8 @@ def node_init(self): Node initialization method, can be overridden by subclasses. Returns a CStatus object indicating whether initialization succeeded. """ + if self.wk_input is None or self.context is None: + return CStatus(-1, "wk_input or context not initialized") if self.wk_input.data_json is not None: self.context.assign_from_json(self.wk_input.data_json) self.wk_input.data_json = None @@ -43,6 +57,8 @@ def run(self): sts = self.node_init() if sts.isErr(): return sts + if self.context is None: + return CStatus(-1, "Context not initialized") self.context.lock() try: data_json = self.context.to_json() @@ -51,24 +67,35 @@ def run(self): try: res = self.operator_schedule(data_json) - except Exception as exc: + except (ValueError, TypeError, KeyError, NotImplementedError) as exc: import traceback node_info = f"Node type: {type(self).__name__}, Node object: {self}" err_msg = f"Node failed: {exc}\n{node_info}\n{traceback.format_exc()}" return CStatus(-1, err_msg) + # For unexpected exceptions, re-raise to let them propagate or be caught elsewhere self.context.lock() try: - if isinstance(res, dict): + if res is not None and isinstance(res, dict): self.context.assign_from_json(res) + elif res is not None: + log.warning("operator_schedule 
returned non-dict type: %s", type(res)) finally: self.context.unlock() return CStatus() - def operator_schedule(self, data_json): + def operator_schedule(self, data_json) -> Optional[Dict]: """ - Interface for scheduling the operator, can be overridden by subclasses. - Returns a CStatus object indicating whether scheduling succeeded. + Operation scheduling method that must be implemented by subclasses. + + Args: + data_json: Context serialized as JSON data + + Returns: + Dictionary of processing results, or None to indicate no update + + Raises: + NotImplementedError: If the subclass has not implemented this method """ - pass + raise NotImplementedError("Subclasses must implement operator_schedule") diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/common_node/merge_rerank_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/common_node/merge_rerank_node.py index 78f53e231..c718086aa 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/common_node/merge_rerank_node.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/common_node/merge_rerank_node.py @@ -52,8 +52,8 @@ def node_init(self): topk_return_results=topk_return_results, ) return super().node_init() - except Exception as e: - log.error(f"Failed to initialize MergeRerankNode: {e}") + except ValueError as e: + log.error("Failed to initialize MergeRerankNode: %s", e) from PyCGraph import CStatus return CStatus(-1, f"MergeRerankNode initialization failed: {e}") @@ -72,12 +72,14 @@ def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]: merged_count = len(result.get("merged_result", [])) log.info( - f"Merge and rerank completed: {vector_count} vector results, " - f"{graph_count} graph results, {merged_count} merged results" + "Merge and rerank completed: %d vector results, %d graph results, %d merged results", + vector_count, + graph_count, + merged_count, ) return result - except Exception as e: - log.error(f"Merge and rerank failed: {e}") + except ValueError as e: + log.error("Merge and rerank failed: 
%s", e) return data_json diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/document_node/chunk_split.py b/hugegraph-llm/src/hugegraph_llm/nodes/document_node/chunk_split.py index f71bd7bd5..883cc909d 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/document_node/chunk_split.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/document_node/chunk_split.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from hugegraph_llm.nodes.base_node import BaseNode from PyCGraph import CStatus +from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.document_op.chunk_split import ChunkSplit from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/fetch_graph_data.py b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/fetch_graph_data.py index 99b428e5e..6e9dd01ad 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/fetch_graph_data.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/fetch_graph_data.py @@ -13,6 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Optional + from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.hugegraph_op.fetch_graph_data import FetchGraphData from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState @@ -21,11 +23,12 @@ class FetchGraphDataNode(BaseNode): fetch_graph_data_op: FetchGraphData - context: WkFlowState = None - wk_input: WkFlowInput = None + context: Optional[WkFlowState] = None + wk_input: Optional[WkFlowInput] = None def node_init(self): - self.fetch_graph_data_op = FetchGraphData(get_hg_client()) + client = get_hg_client() + self.fetch_graph_data_op = FetchGraphData(client) return super().node_init() def operator_schedule(self, data_json): diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/graph_query_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/graph_query_node.py index ae65ccb33..c9d62a9d5 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/graph_query_node.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/graph_query_node.py @@ -13,12 +13,62 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from PyCGraph import CStatus -from typing import Dict, Any +import json +from typing import Dict, Any, Tuple, List, Set, Optional + from hugegraph_llm.nodes.base_node import BaseNode -from hugegraph_llm.operators.hugegraph_op.graph_rag_query import GraphRAGQuery from hugegraph_llm.config import huge_settings, prompt +from hugegraph_llm.operators.operator_list import OperatorList from hugegraph_llm.utils.log import log +from pyhugegraph.client import PyHugeClient + +# TODO: remove 'as('subj)' step +VERTEX_QUERY_TPL = "g.V({keywords}).limit(8).as('subj').toList()" + +# TODO: we could use a simpler query (like kneighbor-api to get the edges) +# TODO: test with profile()/explain() to speed up the query +VID_QUERY_NEIGHBOR_TPL = """\ +g.V({keywords}) +.repeat( + bothE({edge_labels}).limit({edge_limit}).otherV().dedup() +).times({max_deep}).emit() +.simplePath() +.path() +.by(project('label', 'id', 'props') + .by(label()) + .by(id()) + .by(valueMap().by(unfold())) +) +.by(project('label', 'inV', 'outV', 'props') + .by(label()) + .by(inV().id()) + .by(outV().id()) + .by(valueMap().by(unfold())) +) +.limit({max_items}) +.toList() +""" + +PROPERTY_QUERY_NEIGHBOR_TPL = """\ +g.V().has('{prop}', within({keywords})) +.repeat( + bothE({edge_labels}).limit({edge_limit}).otherV().dedup() +).times({max_deep}).emit() +.simplePath() +.path() +.by(project('label', 'props') + .by(label()) + .by(valueMap().by(unfold())) +) +.by(project('label', 'inV', 'outV', 'props') + .by(label()) + .by(inV().values('{prop}')) + .by(outV().values('{prop}')) + .by(valueMap().by(unfold())) +) +.limit({max_items}) +.toList() +""" class GraphQueryNode(BaseNode): @@ -26,45 +76,395 @@ class GraphQueryNode(BaseNode): Graph query node, responsible for retrieving relevant information from the graph database. 
""" - graph_rag_query: GraphRAGQuery + _client: Optional[PyHugeClient] = None + _max_deep: Optional[int] = None + _max_items: Optional[int] = None + _prop_to_match: Optional[str] = None + _num_gremlin_generate_example: int = -1 + gremlin_prompt: str = "" + _limit_property: bool = False + _max_v_prop_len: int = 2048 + _max_e_prop_len: int = 256 + _schema: str = "" + operator_list: Optional[OperatorList] = None def node_init(self): """ Initialize the graph query operator. """ + self._client: PyHugeClient = PyHugeClient( + url=huge_settings.graph_url, + graph=huge_settings.graph_name, + user=huge_settings.graph_user, + pwd=huge_settings.graph_pwd, + graphspace=huge_settings.graph_space, + ) + self._max_deep = self.wk_input.max_deep or 2 + self._max_items = self.wk_input.max_graph_items or huge_settings.max_graph_items + self._prop_to_match = self.wk_input.prop_to_match + self._num_gremlin_generate_example = ( + self.wk_input.gremlin_tmpl_num + if self.wk_input.gremlin_tmpl_num is not None + else -1 + ) + self.gremlin_prompt = ( + self.wk_input.gremlin_prompt or prompt.gremlin_generate_prompt + ) + self._limit_property = huge_settings.limit_property.lower() == "true" + self._max_v_prop_len = self.wk_input.max_v_prop_len or 2048 + self._max_e_prop_len = self.wk_input.max_e_prop_len or 256 + self._schema = "" + self.operator_list = OperatorList(None, None) + + return super().node_init() + + # TODO: move this method to a util file for reuse (remove self param) + def init_client(self, context): + """Initialize the HugeGraph client from context or default settings.""" + # pylint: disable=R0915 (too-many-statements) + if self._client is None: + if isinstance(context.get("graph_client"), PyHugeClient): + self._client = context["graph_client"] + else: + url = context.get("url") or "http://localhost:8080" + graph = context.get("graph") or "hugegraph" + user = context.get("user") or "admin" + pwd = context.get("pwd") or "admin" + gs = context.get("graphspace") or None + 
self._client = PyHugeClient(url, graph, user, pwd, gs) + assert self._client is not None, "No valid graph to search." + + def _gremlin_generate_query(self, context: Dict[str, Any]) -> Dict[str, Any]: + query = context["query"] + vertices = context.get("match_vids") + query_embedding = context.get("query_embedding") + + self.operator_list.clear() + self.operator_list.example_index_query( + num_examples=self._num_gremlin_generate_example + ) + gremlin_response = self.operator_list.gremlin_generate_synthesize( + context["simple_schema"], + vertices=vertices, + gremlin_prompt=self.gremlin_prompt, + ).run(query=query, query_embedding=query_embedding) + if self._num_gremlin_generate_example > 0: + gremlin = gremlin_response["result"] + else: + gremlin = gremlin_response["raw_result"] + log.info("Generated gremlin: %s", gremlin) + context["gremlin"] = gremlin try: - graph_name = huge_settings.graph_name - if not graph_name: - return CStatus(-1, "graph_name is required in wk_input") + result = self._client.gremlin().exec(gremlin=gremlin)["data"] + if result == [None]: + result = [] + context["graph_result"] = [ + json.dumps(item, ensure_ascii=False) for item in result + ] + if context["graph_result"]: + context["graph_result_flag"] = 1 + context["graph_context_head"] = ( + f"The following are graph query result " + f"from gremlin query `{gremlin}`.\n" + ) + except Exception as e: # pylint: disable=broad-except,broad-exception-caught + log.error(e) + context["graph_result"] = [] + return context - max_deep = self.wk_input.max_deep or 2 - max_graph_items = ( - self.wk_input.max_graph_items or huge_settings.max_graph_items - ) - max_v_prop_len = self.wk_input.max_v_prop_len or 2048 - max_e_prop_len = self.wk_input.max_e_prop_len or 256 - prop_to_match = self.wk_input.prop_to_match - num_gremlin_generate_example = self.wk_input.gremlin_tmpl_num or -1 - gremlin_prompt = ( - self.wk_input.gremlin_prompt or prompt.gremlin_generate_prompt + def _limit_property_query( + self, 
value: Optional[str], item_type: str + ) -> Optional[str]: + # NOTE: we skip the filter for list/set type (e.g., list of string, add it if needed) + if not self._limit_property or not isinstance(value, str): + return value + + max_len = self._max_v_prop_len if item_type == "v" else self._max_e_prop_len + return value[:max_len] if value else value + + def _process_vertex( + self, + item: Any, + flat_rel: str, + node_cache: Set[str], + prior_edge_str_len: int, + depth: int, + nodes_with_degree: List[str], + use_id_to_match: bool, + v_cache: Set[str], + ) -> Tuple[str, int, int]: + matched_str = ( + item["id"] if use_id_to_match else item["props"][self._prop_to_match] + ) + if matched_str in node_cache: + flat_rel = flat_rel[:-prior_edge_str_len] + return flat_rel, prior_edge_str_len, depth + + node_cache.add(matched_str) + props_str = ", ".join( + f"{k}: {self._limit_property_query(v, 'v')}" + for k, v in item["props"].items() + if v + ) + + # TODO: we may remove label id or replace with label name + if matched_str in v_cache: + node_str = matched_str + else: + v_cache.add(matched_str) + node_str = f"{item['id']}{{{props_str}}}" + + flat_rel += node_str + nodes_with_degree.append(node_str) + depth += 1 + return flat_rel, prior_edge_str_len, depth + + def _process_edge( + self, + item: Any, + path_str: str, + raw_flat_rel: List[Any], + i: int, + use_id_to_match: bool, + e_cache: Set[Tuple[str, str, str]], + ) -> Tuple[str, int]: + props_str = ", ".join( + f"{k}: {self._limit_property_query(v, 'e')}" + for k, v in item["props"].items() + if v + ) + props_str = f"{{{props_str}}}" if props_str else "" + prev_matched_str = ( + raw_flat_rel[i - 1]["id"] + if use_id_to_match + else (raw_flat_rel)[i - 1]["props"][self._prop_to_match] + ) + + edge_key = (item["inV"], item["label"], item["outV"]) + if edge_key not in e_cache: + e_cache.add(edge_key) + edge_label = f"{item['label']}{props_str}" + else: + edge_label = item["label"] + + edge_str = ( + f"--[{edge_label}]-->" + if 
item["outV"] == prev_matched_str + else f"<--[{edge_label}]--" + ) + path_str += edge_str + prior_edge_str_len = len(edge_str) + return path_str, prior_edge_str_len + + def _process_path( + self, + path: Any, + use_id_to_match: bool, + v_cache: Set[str], + e_cache: Set[Tuple[str, str, str]], + ) -> Tuple[str, List[str]]: + flat_rel = "" + raw_flat_rel = path["objects"] + assert len(raw_flat_rel) % 2 == 1, "The length of raw_flat_rel should be odd." + + node_cache = set() + prior_edge_str_len = 0 + depth = 0 + nodes_with_degree = [] + + for i, item in enumerate(raw_flat_rel): + if i % 2 == 0: + # Process each vertex + flat_rel, prior_edge_str_len, depth = self._process_vertex( + item, + flat_rel, + node_cache, + prior_edge_str_len, + depth, + nodes_with_degree, + use_id_to_match, + v_cache, + ) + else: + # Process each edge + flat_rel, prior_edge_str_len = self._process_edge( + item, flat_rel, raw_flat_rel, i, use_id_to_match, e_cache + ) + + return flat_rel, nodes_with_degree + + def _update_vertex_degree_list( + self, vertex_degree_list: List[Set[str]], nodes_with_degree: List[str] + ) -> None: + for depth, node_str in enumerate(nodes_with_degree): + if depth >= len(vertex_degree_list): + vertex_degree_list.append(set()) + vertex_degree_list[depth].add(node_str) + + def _format_graph_query_result( + self, query_paths + ) -> Tuple[Set[str], List[Set[str]], Dict[str, List[str]]]: + use_id_to_match = self._prop_to_match is None + subgraph = set() + subgraph_with_degree = {} + vertex_degree_list: List[Set[str]] = [] + v_cache: Set[str] = set() + e_cache: Set[Tuple[str, str, str]] = set() + + for path in query_paths: + # 1. Process each path + path_str, vertex_with_degree = self._process_path( + path, use_id_to_match, v_cache, e_cache ) + subgraph.add(path_str) + subgraph_with_degree[path_str] = vertex_with_degree + # 2. 
Update vertex degree list + self._update_vertex_degree_list(vertex_degree_list, vertex_with_degree) + + return subgraph, vertex_degree_list, subgraph_with_degree + + def _get_graph_schema(self, refresh: bool = False) -> str: + if self._schema and not refresh: + return self._schema + + schema = self._client.schema() + vertex_schema = schema.getVertexLabels() + edge_schema = schema.getEdgeLabels() + relationships = schema.getRelations() + + self._schema = ( + f"Vertex properties: {vertex_schema}\n" + f"Edge properties: {edge_schema}\n" + f"Relationships: {relationships}\n" + ) + log.debug("Link(Relation): %s", relationships) + return self._schema + + @staticmethod + def _extract_label_names( + source: str, head: str = "name: ", tail: str = ", " + ) -> List[str]: + result = [] + for s in source.split(head): + end = s.find(tail) + label = s[:end] + if label: + result.append(label) + return result + + def _extract_labels_from_schema(self) -> Tuple[List[str], List[str]]: + schema = self._get_graph_schema() + vertex_props_str, edge_props_str = schema.split("\n")[:2] + # TODO: rename to vertex (also need update in the schema) + vertex_props_str = ( + vertex_props_str[len("Vertex properties: ") :].strip("[").strip("]") + ) + edge_props_str = ( + edge_props_str[len("Edge properties: ") :].strip("[").strip("]") + ) + vertex_labels = self._extract_label_names(vertex_props_str) + edge_labels = self._extract_label_names(edge_props_str) + return vertex_labels, edge_labels + + def _format_graph_from_vertex(self, query_result: List[Any]) -> Set[str]: + knowledge = set() + for item in query_result: + props_str = ", ".join(f"{k}: {v}" for k, v in item["properties"].items()) + node_str = f"{item['id']}{{{props_str}}}" + knowledge.add(node_str) + return knowledge - # Initialize GraphRAGQuery operator - self.graph_rag_query = GraphRAGQuery( - max_deep=max_deep, - max_graph_items=max_graph_items, - max_v_prop_len=max_v_prop_len, - max_e_prop_len=max_e_prop_len, - 
prop_to_match=prop_to_match, - num_gremlin_generate_example=num_gremlin_generate_example, - gremlin_prompt=gremlin_prompt, + def _subgraph_query(self, context: Dict[str, Any]) -> Dict[str, Any]: + # 1. Extract params from context + matched_vids = context.get("match_vids") + if isinstance(context.get("max_deep"), int): + self._max_deep = context["max_deep"] + if isinstance(context.get("max_items"), int): + self._max_items = context["max_items"] + if isinstance(context.get("prop_to_match"), str): + self._prop_to_match = context["prop_to_match"] + + # 2. Extract edge_labels from graph schema + _, edge_labels = self._extract_labels_from_schema() + edge_labels_str = ",".join("'" + label + "'" for label in edge_labels) + # TODO: enhance the limit logic later + edge_limit_amount = len(edge_labels) * huge_settings.edge_limit_pre_label + + use_id_to_match = self._prop_to_match is None + if use_id_to_match: + if not matched_vids: + return context + + gremlin_query = VERTEX_QUERY_TPL.format(keywords=matched_vids) + vertexes = self._client.gremlin().exec(gremlin=gremlin_query)["data"] + log.debug("Vids gremlin query: %s", gremlin_query) + + vertex_knowledge = self._format_graph_from_vertex(query_result=vertexes) + paths: List[Any] = [] + # TODO: use generator or asyncio to speed up the query logic + for matched_vid in matched_vids: + gremlin_query = VID_QUERY_NEIGHBOR_TPL.format( + keywords=f"'{matched_vid}'", + max_deep=self._max_deep, + edge_labels=edge_labels_str, + edge_limit=edge_limit_amount, + max_items=self._max_items, + ) + log.debug("Kneighbor gremlin query: %s", gremlin_query) + paths.extend(self._client.gremlin().exec(gremlin=gremlin_query)["data"]) + + ( + graph_chain_knowledge, + vertex_degree_list, + knowledge_with_degree, + ) = self._format_graph_query_result(query_paths=paths) + + # TODO: we may need to optimize the logic here with global deduplication (may lack some single vertex) + if not graph_chain_knowledge: + 
graph_chain_knowledge.update(vertex_knowledge) + if vertex_degree_list: + vertex_degree_list[0].update(vertex_knowledge) + else: + vertex_degree_list.append(vertex_knowledge) + else: + # WARN: When will the query enter here? + keywords = context.get("keywords") + assert keywords, "No related property(keywords) for graph query." + keywords_str = ",".join("'" + kw + "'" for kw in keywords) + gremlin_query = PROPERTY_QUERY_NEIGHBOR_TPL.format( + prop=self._prop_to_match, + keywords=keywords_str, + edge_labels=edge_labels_str, + edge_limit=edge_limit_amount, + max_deep=self._max_deep, + max_items=self._max_items, + ) + log.warning( + "Unable to find vid, downgraded to property query, please confirm if it meets expectation." ) - return super().node_init() - except Exception as e: - log.error(f"Failed to initialize GraphQueryNode: {e}") + paths: List[Any] = self._client.gremlin().exec(gremlin=gremlin_query)[ + "data" + ] + ( + graph_chain_knowledge, + vertex_degree_list, + knowledge_with_degree, + ) = self._format_graph_query_result(query_paths=paths) - return CStatus(-1, f"GraphQueryNode initialization failed: {e}") + context["graph_result"] = list(graph_chain_knowledge) + if context["graph_result"]: + context["graph_result_flag"] = 0 + context["vertex_degree_list"] = [ + list(vertex_degree) for vertex_degree in vertex_degree_list + ] + context["knowledge_with_degree"] = knowledge_with_degree + context["graph_context_head"] = ( + f"The following are graph knowledge in {self._max_deep} depth, e.g:\n" + "`vertexA--[links]-->vertexB<--[links]--vertexC ...`" + "extracted based on key entities as subject:\n" + ) + return context def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]: """ @@ -79,15 +479,31 @@ def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]: return data_json # Execute the graph query (assuming schema and semantic query have been completed in previous nodes) - graph_result = self.graph_rag_query.run(data_json) - 
data_json.update(graph_result) + self.init_client(data_json) + + # initial flag: -1 means no result, 0 means subgraph query, 1 means gremlin query + data_json["graph_result_flag"] = -1 + # 1. Try to perform a query based on the generated gremlin + if self._num_gremlin_generate_example >= 0: + data_json = self._gremlin_generate_query(data_json) + # 2. Try to perform a query based on subgraph-search if the previous query failed + if not data_json.get("graph_result"): + data_json = self._subgraph_query(data_json) + + if data_json.get("graph_result"): + log.debug( + "Knowledge from Graph:\n%s", "\n".join(data_json["graph_result"]) + ) + else: + log.debug("No Knowledge Extracted from Graph") log.info( - f"Graph query completed, found {len(data_json.get('graph_result', []))} results" + "Graph query completed, found %d results", + len(data_json.get("graph_result", [])), ) return data_json - except Exception as e: - log.error(f"Graph query failed: {e}") + except Exception as e: # pylint: disable=broad-except,broad-exception-caught + log.error("Graph query failed: %s", e) return data_json diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py index 3face9d63..26d74c5d9 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/hugegraph_node/schema.py @@ -15,6 +15,7 @@ import json +from PyCGraph import CStatus from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.common_op.check_schema import CheckSchema from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager @@ -38,23 +39,23 @@ def _import_schema( ): if from_hugegraph: return SchemaManager(from_hugegraph) - elif from_user_defined: + if from_user_defined: return CheckSchema(from_user_defined) - elif from_extraction: + if from_extraction: raise NotImplementedError("Not implemented yet") - else: - raise ValueError("No input data / 
invalid schema type") + raise ValueError("No input data / invalid schema type") def node_init(self): - self.schema = self.wk_input.schema - self.schema = self.schema.strip() + if self.wk_input.schema is None: + return CStatus(-1, "Schema message is required in SchemaNode") + self.schema = self.wk_input.schema.strip() if self.schema.startswith("{"): try: schema = json.loads(self.schema) self.check_schema = self._import_schema(from_user_defined=schema) except json.JSONDecodeError as exc: log.error("Invalid JSON format in schema. Please check it again.") - raise ValueError("Invalid JSON format in schema.") from exc + return CStatus(-1, f"Invalid JSON format in schema. {exc}") else: log.info("Get schema '%s' from graphdb.", self.schema) self.schema_manager = self._import_schema(from_hugegraph=self.schema) @@ -63,11 +64,6 @@ def node_init(self): def operator_schedule(self, data_json): log.debug("SchemaNode input state: %s", data_json) if self.schema.startswith("{"): - try: - return self.check_schema.run(data_json) - except json.JSONDecodeError as exc: - log.error("Invalid JSON format in schema. Please check it again.") - raise ValueError("Invalid JSON format in schema.") from exc - else: - log.info("Get schema '%s' from graphdb.", self.schema) - return self.schema_manager.run(data_json) + return self.check_schema.run(data_json) + log.info("Get schema '%s' from graphdb.", self.schema) + return self.schema_manager.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_gremlin_example_index.py b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_gremlin_example_index.py new file mode 100644 index 000000000..8772959d7 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/build_gremlin_example_index.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from PyCGraph import CStatus + +from hugegraph_llm.config import llm_settings +from hugegraph_llm.models.embeddings.init_embedding import get_embedding +from hugegraph_llm.nodes.base_node import BaseNode +from hugegraph_llm.operators.index_op.build_gremlin_example_index import ( + BuildGremlinExampleIndex, +) +from hugegraph_llm.state.ai_state import WkFlowInput, WkFlowState + + +class BuildGremlinExampleIndexNode(BaseNode): + build_gremlin_example_index_op: BuildGremlinExampleIndex + context: WkFlowState = None + wk_input: WkFlowInput = None + + def node_init(self): + if not self.wk_input.examples: + return CStatus(-1, "examples is required in BuildGremlinExampleIndexNode") + examples = self.wk_input.examples + + self.build_gremlin_example_index_op = BuildGremlinExampleIndex( + get_embedding(llm_settings), examples + ) + return super().node_init() + + def operator_schedule(self, data_json): + return self.build_gremlin_example_index_op.run(data_json) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/semantic_id_query_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/semantic_id_query_node.py index bf605aa49..68d2b72f2 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/semantic_id_query_node.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/semantic_id_query_node.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # 
limitations under the License. -from PyCGraph import CStatus from typing import Dict, Any +from PyCGraph import CStatus from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.index_op.semantic_id_query import SemanticIdQuery from hugegraph_llm.models.embeddings.init_embedding import get_embedding @@ -33,59 +33,61 @@ def node_init(self): """ Initialize the semantic ID query operator. """ - try: - graph_name = huge_settings.graph_name - if not graph_name: - return CStatus(-1, "graph_name is required in wk_input") - - embedding = get_embedding(llm_settings) - by = self.wk_input.semantic_by or "keywords" - topk_per_keyword = ( - self.wk_input.topk_per_keyword or huge_settings.topk_per_keyword - ) - topk_per_query = self.wk_input.topk_per_query or 10 - vector_dis_threshold = ( - self.wk_input.vector_dis_threshold or huge_settings.vector_dis_threshold - ) + graph_name = huge_settings.graph_name + if not graph_name: + return CStatus(-1, "graph_name is required in wk_input") - # Initialize the semantic ID query operator - self.semantic_id_query = SemanticIdQuery( - embedding=embedding, - by=by, - topk_per_keyword=topk_per_keyword, - topk_per_query=topk_per_query, - vector_dis_threshold=vector_dis_threshold, - ) + embedding = get_embedding(llm_settings) + by = ( + self.wk_input.semantic_by + if self.wk_input.semantic_by is not None + else "keywords" + ) + topk_per_keyword = ( + self.wk_input.topk_per_keyword + if self.wk_input.topk_per_keyword is not None + else huge_settings.topk_per_keyword + ) + topk_per_query = ( + self.wk_input.topk_per_query + if self.wk_input.topk_per_query is not None + else 10 + ) + vector_dis_threshold = ( + self.wk_input.vector_dis_threshold + if self.wk_input.vector_dis_threshold is not None + else huge_settings.vector_dis_threshold + ) - return super().node_init() - except Exception as e: - log.error(f"Failed to initialize SemanticIdQueryNode: {e}") + # Initialize the semantic ID query operator + self.semantic_id_query 
= SemanticIdQuery( + embedding=embedding, + by=by, + topk_per_keyword=topk_per_keyword, + topk_per_query=topk_per_query, + vector_dis_threshold=vector_dis_threshold, + ) - return CStatus(-1, f"SemanticIdQueryNode initialization failed: {e}") + return super().node_init() def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]: """ Execute the semantic ID query operation. """ - try: - # Get the query text and keywords from input - query = data_json.get("query", "") - keywords = data_json.get("keywords", []) + # Get the query text and keywords from input + query = data_json.get("query", "") + keywords = data_json.get("keywords", []) - if not query and not keywords: - log.warning("No query text or keywords provided for semantic query") - return data_json - - # Perform the semantic query - semantic_result = self.semantic_id_query.run(data_json) + if not query and not keywords: + log.warning("No query text or keywords provided for semantic query") + return data_json - match_vids = semantic_result.get("match_vids", []) - log.info( - f"Semantic query completed, found {len(match_vids)} matching vertex IDs" - ) + # Perform the semantic query + semantic_result = self.semantic_id_query.run(data_json) - return semantic_result + match_vids = semantic_result.get("match_vids", []) + log.info( + "Semantic query completed, found %d matching vertex IDs", len(match_vids) + ) - except Exception as e: - log.error(f"Semantic query failed: {e}") - return data_json + return semantic_result diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/vector_query_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/vector_query_node.py index 48b50acf3..9c8104c6e 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/index_node/vector_query_node.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/index_node/vector_query_node.py @@ -32,20 +32,14 @@ def node_init(self): """ Initialize the vector query operator """ - try: - # 从 wk_input 中读取用户配置参数 - embedding = 
get_embedding(llm_settings) - max_items = ( - self.wk_input.max_items if self.wk_input.max_items is not None else 3 - ) - - self.operator = VectorIndexQuery(embedding=embedding, topk=max_items) - return super().node_init() - except Exception as e: - log.error(f"Failed to initialize VectorQueryNode: {e}") - from PyCGraph import CStatus + # 从 wk_input 中读取用户配置参数 + embedding = get_embedding(llm_settings) + max_items = ( + self.wk_input.max_items if self.wk_input.max_items is not None else 3 + ) - return CStatus(-1, f"VectorQueryNode initialization failed: {e}") + self.operator = VectorIndexQuery(embedding=embedding, topk=max_items) + return super().node_init() def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]: """ @@ -64,11 +58,12 @@ def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]: # Update the state data_json.update(result) log.info( - f"Vector query completed, found {len(result.get('vector_result', []))} results" + "Vector query completed, found %d results", + len(result.get("vector_result", [])), ) return data_json - except Exception as e: - log.error(f"Vector query failed: {e}") + except ValueError as e: + log.error("Vector query failed: %s", e) return data_json diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/answer_synthesize_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/answer_synthesize_node.py index 22b970b4a..6997cd781 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/answer_synthesize_node.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/answer_synthesize_node.py @@ -30,70 +30,57 @@ def node_init(self): """ Initialize the answer synthesis operator. 
""" - try: - prompt_template = self.wk_input.answer_prompt - raw_answer = self.wk_input.raw_answer or False - vector_only_answer = self.wk_input.vector_only_answer or False - graph_only_answer = self.wk_input.graph_only_answer or False - graph_vector_answer = self.wk_input.graph_vector_answer or False + prompt_template = self.wk_input.answer_prompt + raw_answer = self.wk_input.raw_answer or False + vector_only_answer = self.wk_input.vector_only_answer or False + graph_only_answer = self.wk_input.graph_only_answer or False + graph_vector_answer = self.wk_input.graph_vector_answer or False - self.operator = AnswerSynthesize( - prompt_template=prompt_template, - raw_answer=raw_answer, - vector_only_answer=vector_only_answer, - graph_only_answer=graph_only_answer, - graph_vector_answer=graph_vector_answer, - ) - return super().node_init() - except Exception as e: - log.error(f"Failed to initialize AnswerSynthesizeNode: {e}") - from PyCGraph import CStatus - - return CStatus(-1, f"AnswerSynthesizeNode initialization failed: {e}") + self.operator = AnswerSynthesize( + prompt_template=prompt_template, + raw_answer=raw_answer, + vector_only_answer=vector_only_answer, + graph_only_answer=graph_only_answer, + graph_vector_answer=graph_vector_answer, + ) + return super().node_init() def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]: """ Execute the answer synthesis operation. 
""" - try: - if self.getGParamWithNoEmpty("wkflow_input").stream: - # Streaming mode: return a generator for streaming output - data_json["stream_generator"] = self.operator.run_streaming(data_json) - return data_json - else: - # Non-streaming mode: execute answer synthesis - result = self.operator.run(data_json) - - # Record the types of answers generated - answer_types = [] - if result.get("raw_answer"): - answer_types.append("raw") - if result.get("vector_only_answer"): - answer_types.append("vector_only") - if result.get("graph_only_answer"): - answer_types.append("graph_only") - if result.get("graph_vector_answer"): - answer_types.append("graph_vector") + if self.getGParamWithNoEmpty("wkflow_input").stream: + # Streaming mode: return a generator for streaming output + data_json["stream_generator"] = self.operator.run_streaming(data_json) + return data_json + # Non-streaming mode: execute answer synthesis + result = self.operator.run(data_json) - log.info( - f"Answer synthesis completed for types: {', '.join(answer_types)}" - ) + # Record the types of answers generated + answer_types = [] + if result.get("raw_answer"): + answer_types.append("raw") + if result.get("vector_only_answer"): + answer_types.append("vector_only") + if result.get("graph_only_answer"): + answer_types.append("graph_only") + if result.get("graph_vector_answer"): + answer_types.append("graph_vector") - # Print enabled answer types according to self.wk_input configuration - wk_input_types = [] - if getattr(self.wk_input, "raw_answer", False): - wk_input_types.append("raw") - if getattr(self.wk_input, "vector_only_answer", False): - wk_input_types.append("vector_only") - if getattr(self.wk_input, "graph_only_answer", False): - wk_input_types.append("graph_only") - if getattr(self.wk_input, "graph_vector_answer", False): - wk_input_types.append("graph_vector") - log.info( - f"Enabled answer types according to wk_input config: {', '.join(wk_input_types)}" - ) - return result + log.info("Answer 
synthesis completed for types: %s", ", ".join(answer_types)) - except Exception as e: - log.error(f"Answer synthesis failed: {e}") - return data_json + # Print enabled answer types according to self.wk_input configuration + wk_input_types = [] + if getattr(self.wk_input, "raw_answer", False): + wk_input_types.append("raw") + if getattr(self.wk_input, "vector_only_answer", False): + wk_input_types.append("vector_only") + if getattr(self.wk_input, "graph_only_answer", False): + wk_input_types.append("graph_only") + if getattr(self.wk_input, "graph_vector_answer", False): + wk_input_types.append("graph_vector") + log.info( + "Enabled answer types according to wk_input config: %s", + ", ".join(wk_input_types), + ) + return result diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py index 628765f58..3c9bf2308 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/extract_info.py @@ -48,5 +48,6 @@ def node_init(self): def operator_schedule(self, data_json): if self.extract_type == "triples": return self.info_extract.run(data_json) - elif self.extract_type == "property_graph": + if self.extract_type == "property_graph": return self.property_graph_extract.run(data_json) + raise ValueError("Unsupport extract type") diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/keyword_extract_node.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/keyword_extract_node.py index 76fc06eb3..60542ddc1 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/keyword_extract_node.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/keyword_extract_node.py @@ -14,7 +14,6 @@ # limitations under the License. 
from typing import Dict, Any -from PyCGraph import CStatus from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.llm_op.keyword_extract import KeywordExtract @@ -32,29 +31,17 @@ def node_init(self): """ Initialize the keyword extraction operator. """ - try: - max_keywords = ( - self.wk_input.max_keywords - if self.wk_input.max_keywords is not None - else 5 - ) - language = ( - self.wk_input.language - if self.wk_input.language is not None - else "english" - ) - extract_template = self.wk_input.keywords_extract_prompt + max_keywords = ( + self.wk_input.max_keywords if self.wk_input.max_keywords is not None else 5 + ) + extract_template = self.wk_input.keywords_extract_prompt - self.operator = KeywordExtract( - text=self.wk_input.query, - max_keywords=max_keywords, - language=language, - extract_template=extract_template, - ) - return super().node_init() - except Exception as e: - log.error(f"Failed to initialize KeywordExtractNode: {e}") - return CStatus(-1, f"KeywordExtractNode initialization failed: {e}") + self.operator = KeywordExtract( + text=self.wk_input.query, + max_keywords=max_keywords, + extract_template=extract_template, + ) + return super().node_init() def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]: """ @@ -67,12 +54,12 @@ def operator_schedule(self, data_json: Dict[str, Any]) -> Dict[str, Any]: log.warning("Keyword extraction result missing 'keywords' field") result["keywords"] = [] - log.info(f"Extracted keywords: {result.get('keywords', [])}") + log.info("Extracted keywords: %s", result.get("keywords", [])) return result - except Exception as e: - log.error(f"Keyword extraction failed: {e}") + except ValueError as e: + log.error("Keyword extraction failed: %s", e) # Add error flag to indicate failure error_result = data_json.copy() error_result["error"] = str(e) diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py 
b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py index 408adb10a..1ef7e5c55 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/schema_build.py @@ -61,12 +61,16 @@ def node_init(self): # few_shot_schema: already parsed dict or raw JSON string few_shot_schema = {} - fss_src = self.wk_input.few_shot_schema if self.wk_input.few_shot_schema else None + fss_src = ( + self.wk_input.few_shot_schema if self.wk_input.few_shot_schema else None + ) if fss_src: try: few_shot_schema = json.loads(fss_src) except json.JSONDecodeError as e: - return CStatus(-1, f"Few Shot Schema is not in a valid JSON format: {e}") + return CStatus( + -1, f"Few Shot Schema is not in a valid JSON format: {e}" + ) _context_payload = { "raw_texts": raw_texts, @@ -82,6 +86,6 @@ def operator_schedule(self, data_json): schema_result = self.schema_builder.run(data_json) return {"schema": schema_result} - except Exception as e: + except (ValueError, RuntimeError) as e: log.error("Failed to generate schema: %s", e) return {"schema": f"Schema generation failed: {e}"} diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/text2gremlin.py b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/text2gremlin.py index a36831526..0904b9920 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/text2gremlin.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/llm_node/text2gremlin.py @@ -18,7 +18,6 @@ import json from typing import Any, Dict, Optional -from PyCGraph import CStatus from hugegraph_llm.nodes.base_node import BaseNode from hugegraph_llm.operators.llm_op.gremlin_generate import GremlinGenerateSynthesize @@ -27,13 +26,14 @@ def _stable_schema_string(state_json: Dict[str, Any]) -> str: - if "simple_schema" in state_json and state_json["simple_schema"] is not None: - return json.dumps( - state_json["simple_schema"], ensure_ascii=False, sort_keys=True - ) - if "schema" in state_json and state_json["schema"] is not 
None: - return json.dumps(state_json["schema"], ensure_ascii=False, sort_keys=True) - return "" + val = state_json.get("simple_schema") + if val is None: + val = state_json.get("schema") + if val is None: + return "" + if isinstance(val, str): + return val + return json.dumps(val, ensure_ascii=False, sort_keys=True) class Text2GremlinNode(BaseNode): @@ -56,7 +56,7 @@ def node_init(self): vertices=None, gremlin_prompt=gremlin_prompt, ) - return CStatus() + return super().node_init() def operator_schedule(self, data_json: Dict[str, Any]): # Ensure query exists in context; return empty if not provided diff --git a/hugegraph-llm/src/hugegraph_llm/nodes/util.py b/hugegraph-llm/src/hugegraph_llm/nodes/util.py index 60bdc2e86..d1ac69657 100644 --- a/hugegraph-llm/src/hugegraph_llm/nodes/util.py +++ b/hugegraph-llm/src/hugegraph_llm/nodes/util.py @@ -13,15 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any + from PyCGraph import CStatus -def init_context(obj) -> CStatus: - try: - obj.context = obj.getGParamWithNoEmpty("wkflow_state") - obj.wk_input = obj.getGParamWithNoEmpty("wkflow_input") - if obj.context is None or obj.wk_input is None: - return CStatus(-1, "Required workflow parameters not found") - return CStatus() - except Exception as e: - return CStatus(-1, f"Failed to initialize context: {str(e)}") +def init_context(obj: Any) -> CStatus: + """ + Initialize workflow context for a node. + + Retrieves wkflow_state and wkflow_input from obj's global parameters + and assigns them to obj.context and obj.wk_input respectively. 
+ + Args: + obj: Node object with getGParamWithNoEmpty method + + Returns: + CStatus: Empty status on success, error status with code -1 on failure + """ + obj.context = obj.getGParamWithNoEmpty("wkflow_state") + obj.wk_input = obj.getGParamWithNoEmpty("wkflow_input") + if obj.context is None or obj.wk_input is None: + return CStatus(-1, "Required workflow parameters not found") + return CStatus() diff --git a/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py b/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py deleted file mode 100644 index 70f3d27d2..000000000 --- a/hugegraph-llm/src/hugegraph_llm/operators/gremlin_generate_task.py +++ /dev/null @@ -1,81 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-from typing import Optional, List - -from hugegraph_llm.models.embeddings.base import BaseEmbedding -from hugegraph_llm.models.llms.base import BaseLLM -from hugegraph_llm.operators.common_op.check_schema import CheckSchema -from hugegraph_llm.operators.common_op.print_result import PrintResult -from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager -from hugegraph_llm.operators.index_op.build_gremlin_example_index import BuildGremlinExampleIndex -from hugegraph_llm.operators.index_op.gremlin_example_index_query import GremlinExampleIndexQuery -from hugegraph_llm.operators.llm_op.gremlin_generate import GremlinGenerateSynthesize -from hugegraph_llm.utils.decorators import log_time, log_operator_time, record_rpm - - -class GremlinGenerator: - def __init__(self, llm: BaseLLM, embedding: BaseEmbedding): - self.embedding = [] - self.llm = llm - self.embedding = embedding - self.result = None - self.operators = [] - - def clear(self): - self.operators = [] - return self - - def example_index_build(self, examples): - self.operators.append(BuildGremlinExampleIndex(self.embedding, examples)) - return self - - def import_schema(self, from_hugegraph=None, from_extraction=None, from_user_defined=None): - if from_hugegraph: - self.operators.append(SchemaManager(from_hugegraph)) - elif from_user_defined: - self.operators.append(CheckSchema(from_user_defined)) - elif from_extraction: - raise NotImplementedError("Not implemented yet") - else: - raise ValueError("No input data / invalid schema type") - return self - - def example_index_query(self, num_examples): - self.operators.append(GremlinExampleIndexQuery(self.embedding, num_examples)) - return self - - def gremlin_generate_synthesize( - self, schema, gremlin_prompt: Optional[str] = None, vertices: Optional[List[str]] = None - ): - self.operators.append(GremlinGenerateSynthesize(self.llm, schema, vertices, gremlin_prompt)) - return self - - def print_result(self): - self.operators.append(PrintResult()) 
- return self - - @log_time("total time") - @record_rpm - def run(self, **kwargs): - context = kwargs - for operator in self.operators: - context = self._run_operator(operator, context) - return context - - @log_operator_time - def _run_operator(self, operator, context): - return operator.run(context) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py index 52626b72b..ba4392f7c 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/commit_to_hugegraph.py @@ -40,7 +40,6 @@ def run(self, data: dict) -> Dict[str, Any]: schema = data.get("schema") vertices = data.get("vertices", []) edges = data.get("edges", []) - print(f"get schema {schema}") if not vertices and not edges: log.critical( "(Loading) Both vertices and edges are empty. Please check the input data again." @@ -50,7 +49,9 @@ def run(self, data: dict) -> Dict[str, Any]: if not schema: # TODO: ensure the function works correctly (update the logic later) self.schema_free_mode(data.get("triples", [])) - log.warning("Using schema_free mode, could try schema_define mode for better effect!") + log.warning( + "Using schema_free mode, could try schema_define mode for better effect!" 
+ ) else: self.init_schema_if_need(schema) self.load_into_graph(vertices, edges, schema) @@ -66,7 +67,9 @@ def _set_default_property(self, key, input_properties, property_label_map): # list or set default_value = [] input_properties[key] = default_value - log.warning("Property '%s' missing in vertex, set to '%s' for now", key, default_value) + log.warning( + "Property '%s' missing in vertex, set to '%s' for now", key, default_value + ) def _handle_graph_creation(self, func, *args, **kwargs): try: @@ -80,9 +83,13 @@ def _handle_graph_creation(self, func, *args, **kwargs): def load_into_graph(self, vertices, edges, schema): # pylint: disable=too-many-statements # pylint: disable=R0912 (too-many-branches) - vertex_label_map = {v_label["name"]: v_label for v_label in schema["vertexlabels"]} + vertex_label_map = { + v_label["name"]: v_label for v_label in schema["vertexlabels"] + } edge_label_map = {e_label["name"]: e_label for e_label in schema["edgelabels"]} - property_label_map = {p_label["name"]: p_label for p_label in schema["propertykeys"]} + property_label_map = { + p_label["name"]: p_label for p_label in schema["propertykeys"] + } for vertex in vertices: input_label = vertex["label"] @@ -98,7 +105,9 @@ def load_into_graph(self, vertices, edges, schema): # pylint: disable=too-many- vertex_label = vertex_label_map[input_label] primary_keys = vertex_label["primary_keys"] nullable_keys = vertex_label.get("nullable_keys", []) - non_null_keys = [key for key in vertex_label["properties"] if key not in nullable_keys] + non_null_keys = [ + key for key in vertex_label["properties"] if key not in nullable_keys + ] has_problem = False # 2. Handle primary-keys mode vertex @@ -130,7 +139,9 @@ def load_into_graph(self, vertices, edges, schema): # pylint: disable=too-many- # 3. 
Ensure all non-nullable props are set for key in non_null_keys: if key not in input_properties: - self._set_default_property(key, input_properties, property_label_map) + self._set_default_property( + key, input_properties, property_label_map + ) # 4. Check all data type value is right for key, value in input_properties.items(): @@ -167,7 +178,9 @@ def load_into_graph(self, vertices, edges, schema): # pylint: disable=too-many- continue # TODO: we could try batch add edges first, setback to single-mode if failed - self._handle_graph_creation(self.client.graph().addEdge, label, start, end, properties) + self._handle_graph_creation( + self.client.graph().addEdge, label, start, end, properties + ) def init_schema_if_need(self, schema: dict): properties = schema["propertykeys"] @@ -191,18 +204,20 @@ def init_schema_if_need(self, schema: dict): source_vertex_label = edge["source_label"] target_vertex_label = edge["target_label"] properties = edge["properties"] - self.schema.edgeLabel(edge_label).sourceLabel(source_vertex_label).targetLabel( - target_vertex_label - ).properties(*properties).nullableKeys(*properties).ifNotExist().create() + self.schema.edgeLabel(edge_label).sourceLabel( + source_vertex_label + ).targetLabel(target_vertex_label).properties(*properties).nullableKeys( + *properties + ).ifNotExist().create() def schema_free_mode(self, data): self.schema.propertyKey("name").asText().ifNotExist().create() self.schema.vertexLabel("vertex").useCustomizeStringId().properties( "name" ).ifNotExist().create() - self.schema.edgeLabel("edge").sourceLabel("vertex").targetLabel("vertex").properties( - "name" - ).ifNotExist().create() + self.schema.edgeLabel("edge").sourceLabel("vertex").targetLabel( + "vertex" + ).properties("name").ifNotExist().create() self.schema.indexLabel("vertexByName").onV("vertex").by( "name" @@ -262,7 +277,9 @@ def _set_property_data_type(self, property_key, data_type): log.warning("UUID type is not supported, use text instead") 
property_key.asText() else: - log.error("Unknown data type %s for property_key %s", data_type, property_key) + log.error( + "Unknown data type %s for property_key %s", data_type, property_key + ) def _set_property_cardinality(self, property_key, cardinality): if cardinality == PropertyCardinality.SINGLE: @@ -272,9 +289,13 @@ def _set_property_cardinality(self, property_key, cardinality): elif cardinality == PropertyCardinality.SET: property_key.valueSet() else: - log.error("Unknown cardinality %s for property_key %s", cardinality, property_key) + log.error( + "Unknown cardinality %s for property_key %s", cardinality, property_key + ) - def _check_property_data_type(self, data_type: str, cardinality: str, value) -> bool: + def _check_property_data_type( + self, data_type: str, cardinality: str, value + ) -> bool: if cardinality in ( PropertyCardinality.LIST.value, PropertyCardinality.SET.value, @@ -304,7 +325,9 @@ def _check_single_data_type(self, data_type: str, value) -> bool: if data_type in (PropertyDataType.TEXT.value, PropertyDataType.UUID.value): return isinstance(value, str) # TODO: check ok below - if data_type == PropertyDataType.DATE.value: # the format should be "yyyy-MM-dd" + if ( + data_type == PropertyDataType.DATE.value + ): # the format should be "yyyy-MM-dd" import re return isinstance(value, str) and re.match(r"^\d{4}-\d{2}-\d{2}$", value) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py b/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py deleted file mode 100644 index bcff5f07b..000000000 --- a/hugegraph-llm/src/hugegraph_llm/operators/hugegraph_op/graph_rag_query.py +++ /dev/null @@ -1,455 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. 
The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import json -from typing import Any, Dict, Optional, List, Set, Tuple - -from hugegraph_llm.config import huge_settings, prompt -from hugegraph_llm.models.embeddings.base import BaseEmbedding -from hugegraph_llm.models.llms.base import BaseLLM -from hugegraph_llm.operators.gremlin_generate_task import GremlinGenerator -from hugegraph_llm.utils.log import log -from pyhugegraph.client import PyHugeClient - -# TODO: remove 'as('subj)' step -VERTEX_QUERY_TPL = "g.V({keywords}).limit(8).as('subj').toList()" - -# TODO: we could use a simpler query (like kneighbor-api to get the edges) -# TODO: test with profile()/explain() to speed up the query -VID_QUERY_NEIGHBOR_TPL = """\ -g.V({keywords}) -.repeat( - bothE({edge_labels}).limit({edge_limit}).otherV().dedup() -).times({max_deep}).emit() -.simplePath() -.path() -.by(project('label', 'id', 'props') - .by(label()) - .by(id()) - .by(valueMap().by(unfold())) -) -.by(project('label', 'inV', 'outV', 'props') - .by(label()) - .by(inV().id()) - .by(outV().id()) - .by(valueMap().by(unfold())) -) -.limit({max_items}) -.toList() -""" - -PROPERTY_QUERY_NEIGHBOR_TPL = """\ -g.V().has('{prop}', within({keywords})) -.repeat( - bothE({edge_labels}).limit({edge_limit}).otherV().dedup() -).times({max_deep}).emit() -.simplePath() -.path() -.by(project('label', 'props') - .by(label()) - .by(valueMap().by(unfold())) -) -.by(project('label', 'inV', 'outV', 'props') - 
.by(label()) - .by(inV().values('{prop}')) - .by(outV().values('{prop}')) - .by(valueMap().by(unfold())) -) -.limit({max_items}) -.toList() -""" - - -class GraphRAGQuery: - def __init__( - self, - max_deep: int = 2, - max_graph_items: int = huge_settings.max_graph_items, - prop_to_match: Optional[str] = None, - llm: Optional[BaseLLM] = None, - embedding: Optional[BaseEmbedding] = None, - max_v_prop_len: Optional[int] = 2048, - max_e_prop_len: Optional[int] = 256, - num_gremlin_generate_example: Optional[int] = -1, - gremlin_prompt: Optional[str] = None, - ): - self._client = PyHugeClient( - url=huge_settings.graph_url, - graph=huge_settings.graph_name, - user=huge_settings.graph_user, - pwd=huge_settings.graph_pwd, - graphspace=huge_settings.graph_space, - ) - self._max_deep = max_deep - self._max_items = max_graph_items - self._prop_to_match = prop_to_match - self._schema = "" - self._limit_property = huge_settings.limit_property.lower() == "true" - self._max_v_prop_len = max_v_prop_len - self._max_e_prop_len = max_e_prop_len - self._gremlin_generator = GremlinGenerator( - llm=llm, - embedding=embedding, - ) - self._num_gremlin_generate_example = num_gremlin_generate_example - self._gremlin_prompt = gremlin_prompt or prompt.gremlin_generate_prompt - - def run(self, context: Dict[str, Any]) -> Dict[str, Any]: - self.init_client(context) - - # initial flag: -1 means no result, 0 means subgraph query, 1 means gremlin query - context["graph_result_flag"] = -1 - # 1. Try to perform a query based on the generated gremlin - if self._num_gremlin_generate_example >= 0: - context = self._gremlin_generate_query(context) - # 2. 
Try to perform a query based on subgraph-search if the previous query failed - if not context.get("graph_result"): - context = self._subgraph_query(context) - - if context.get("graph_result"): - log.debug("Knowledge from Graph:\n%s", "\n".join(context["graph_result"])) - else: - log.debug("No Knowledge Extracted from Graph") - return context - - def _gremlin_generate_query(self, context: Dict[str, Any]) -> Dict[str, Any]: - query = context["query"] - vertices = context.get("match_vids") - query_embedding = context.get("query_embedding") - - self._gremlin_generator.clear() - self._gremlin_generator.example_index_query(num_examples=self._num_gremlin_generate_example) - gremlin_response = self._gremlin_generator.gremlin_generate_synthesize( - context["simple_schema"], vertices=vertices, gremlin_prompt=self._gremlin_prompt - ).run(query=query, query_embedding=query_embedding) - if self._num_gremlin_generate_example > 0: - gremlin = gremlin_response["result"] - else: - gremlin = gremlin_response["raw_result"] - log.info("Generated gremlin: %s", gremlin) - context["gremlin"] = gremlin - try: - result = self._client.gremlin().exec(gremlin=gremlin)["data"] - if result == [None]: - result = [] - context["graph_result"] = [json.dumps(item, ensure_ascii=False) for item in result] - if context["graph_result"]: - context["graph_result_flag"] = 1 - context["graph_context_head"] = ( - f"The following are graph query result " f"from gremlin query `{gremlin}`.\n" - ) - except Exception as e: # pylint: disable=broad-except - log.error(e) - context["graph_result"] = "" - return context - - def _subgraph_query(self, context: Dict[str, Any]) -> Dict[str, Any]: - # 1. 
Extract params from context - matched_vids = context.get("match_vids") - if isinstance(context.get("max_deep"), int): - self._max_deep = context["max_deep"] - if isinstance(context.get("max_items"), int): - self._max_items = context["max_items"] - if isinstance(context.get("prop_to_match"), str): - self._prop_to_match = context["prop_to_match"] - - # 2. Extract edge_labels from graph schema - _, edge_labels = self._extract_labels_from_schema() - edge_labels_str = ",".join("'" + label + "'" for label in edge_labels) - # TODO: enhance the limit logic later - edge_limit_amount = len(edge_labels) * huge_settings.edge_limit_pre_label - - use_id_to_match = self._prop_to_match is None - if use_id_to_match: - if not matched_vids: - return context - - gremlin_query = VERTEX_QUERY_TPL.format(keywords=matched_vids) - vertexes = self._client.gremlin().exec(gremlin=gremlin_query)["data"] - log.debug("Vids gremlin query: %s", gremlin_query) - - vertex_knowledge = self._format_graph_from_vertex(query_result=vertexes) - paths: List[Any] = [] - # TODO: use generator or asyncio to speed up the query logic - for matched_vid in matched_vids: - gremlin_query = VID_QUERY_NEIGHBOR_TPL.format( - keywords=f"'{matched_vid}'", - max_deep=self._max_deep, - edge_labels=edge_labels_str, - edge_limit=edge_limit_amount, - max_items=self._max_items, - ) - log.debug("Kneighbor gremlin query: %s", gremlin_query) - paths.extend(self._client.gremlin().exec(gremlin=gremlin_query)["data"]) - - graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = ( - self._format_graph_query_result(query_paths=paths) - ) - - # TODO: we may need to optimize the logic here with global deduplication (may lack some single vertex) - if not graph_chain_knowledge: - graph_chain_knowledge.update(vertex_knowledge) - if vertex_degree_list: - vertex_degree_list[0].update(vertex_knowledge) - else: - vertex_degree_list.append(vertex_knowledge) - else: - # WARN: When will the query enter here? 
- keywords = context.get("keywords") - assert keywords, "No related property(keywords) for graph query." - keywords_str = ",".join("'" + kw + "'" for kw in keywords) - gremlin_query = PROPERTY_QUERY_NEIGHBOR_TPL.format( - prop=self._prop_to_match, - keywords=keywords_str, - edge_labels=edge_labels_str, - edge_limit=edge_limit_amount, - max_deep=self._max_deep, - max_items=self._max_items, - ) - log.warning( - "Unable to find vid, downgraded to property query, please confirm if it meets expectation." - ) - - paths: List[Any] = self._client.gremlin().exec(gremlin=gremlin_query)["data"] - graph_chain_knowledge, vertex_degree_list, knowledge_with_degree = ( - self._format_graph_query_result(query_paths=paths) - ) - - context["graph_result"] = list(graph_chain_knowledge) - if context["graph_result"]: - context["graph_result_flag"] = 0 - context["vertex_degree_list"] = [ - list(vertex_degree) for vertex_degree in vertex_degree_list - ] - context["knowledge_with_degree"] = knowledge_with_degree - context["graph_context_head"] = ( - f"The following are graph knowledge in {self._max_deep} depth, e.g:\n" - "`vertexA--[links]-->vertexB<--[links]--vertexC ...`" - "extracted based on key entities as subject:\n" - ) - return context - - # TODO: move this method to a util file for reuse (remove self param) - def init_client(self, context): - """Initialize the HugeGraph client from context or default settings.""" - # pylint: disable=R0915 (too-many-statements) - if self._client is None: - if isinstance(context.get("graph_client"), PyHugeClient): - self._client = context["graph_client"] - else: - url = context.get("url") or "http://localhost:8080" - graph = context.get("graph") or "hugegraph" - user = context.get("user") or "admin" - pwd = context.get("pwd") or "admin" - gs = context.get("graphspace") or None - self._client = PyHugeClient(url, graph, user, pwd, gs) - assert self._client is not None, "No valid graph to search." 
- - def get_vertex_details(self, vertex_ids: List[str]) -> List[Dict[str, Any]]: - if not vertex_ids: - return [] - - formatted_ids = ", ".join(f"'{vid}'" for vid in vertex_ids) - gremlin_query = f"g.V({formatted_ids}).limit(20)" - result = self._client.gremlin().exec(gremlin=gremlin_query)["data"] - return result - - def _format_graph_from_vertex(self, query_result: List[Any]) -> Set[str]: - knowledge = set() - for item in query_result: - props_str = ", ".join(f"{k}: {v}" for k, v in item["properties"].items()) - node_str = f"{item['id']}{{{props_str}}}" - knowledge.add(node_str) - return knowledge - - def _format_graph_query_result( - self, query_paths - ) -> Tuple[Set[str], List[Set[str]], Dict[str, List[str]]]: - use_id_to_match = self._prop_to_match is None - subgraph = set() - subgraph_with_degree = {} - vertex_degree_list: List[Set[str]] = [] - v_cache: Set[str] = set() - e_cache: Set[Tuple[str, str, str]] = set() - - for path in query_paths: - # 1. Process each path - path_str, vertex_with_degree = self._process_path( - path, use_id_to_match, v_cache, e_cache - ) - subgraph.add(path_str) - subgraph_with_degree[path_str] = vertex_with_degree - # 2. Update vertex degree list - self._update_vertex_degree_list(vertex_degree_list, vertex_with_degree) - - return subgraph, vertex_degree_list, subgraph_with_degree - - def _process_path( - self, - path: Any, - use_id_to_match: bool, - v_cache: Set[str], - e_cache: Set[Tuple[str, str, str]], - ) -> Tuple[str, List[str]]: - flat_rel = "" - raw_flat_rel = path["objects"] - assert len(raw_flat_rel) % 2 == 1, "The length of raw_flat_rel should be odd." 
- - node_cache = set() - prior_edge_str_len = 0 - depth = 0 - nodes_with_degree = [] - - for i, item in enumerate(raw_flat_rel): - if i % 2 == 0: - # Process each vertex - flat_rel, prior_edge_str_len, depth = self._process_vertex( - item, - flat_rel, - node_cache, - prior_edge_str_len, - depth, - nodes_with_degree, - use_id_to_match, - v_cache, - ) - else: - # Process each edge - flat_rel, prior_edge_str_len = self._process_edge( - item, flat_rel, raw_flat_rel, i, use_id_to_match, e_cache - ) - - return flat_rel, nodes_with_degree - - def _process_vertex( - self, - item: Any, - flat_rel: str, - node_cache: Set[str], - prior_edge_str_len: int, - depth: int, - nodes_with_degree: List[str], - use_id_to_match: bool, - v_cache: Set[str], - ) -> Tuple[str, int, int]: - matched_str = item["id"] if use_id_to_match else item["props"][self._prop_to_match] - if matched_str in node_cache: - flat_rel = flat_rel[:-prior_edge_str_len] - return flat_rel, prior_edge_str_len, depth - - node_cache.add(matched_str) - props_str = ", ".join( - f"{k}: {self._limit_property_query(v, 'v')}" for k, v in item["props"].items() if v - ) - - # TODO: we may remove label id or replace with label name - if matched_str in v_cache: - node_str = matched_str - else: - v_cache.add(matched_str) - node_str = f"{item['id']}{{{props_str}}}" - - flat_rel += node_str - nodes_with_degree.append(node_str) - depth += 1 - return flat_rel, prior_edge_str_len, depth - - def _process_edge( - self, - item: Any, - path_str: str, - raw_flat_rel: List[Any], - i: int, - use_id_to_match: bool, - e_cache: Set[Tuple[str, str, str]], - ) -> Tuple[str, int]: - props_str = ", ".join( - f"{k}: {self._limit_property_query(v, 'e')}" for k, v in item["props"].items() if v - ) - props_str = f"{{{props_str}}}" if props_str else "" - prev_matched_str = ( - raw_flat_rel[i - 1]["id"] - if use_id_to_match - else (raw_flat_rel)[i - 1]["props"][self._prop_to_match] - ) - - edge_key = (item["inV"], item["label"], item["outV"]) - if 
edge_key not in e_cache: - e_cache.add(edge_key) - edge_label = f"{item['label']}{props_str}" - else: - edge_label = item["label"] - - edge_str = ( - f"--[{edge_label}]-->" if item["outV"] == prev_matched_str else f"<--[{edge_label}]--" - ) - path_str += edge_str - prior_edge_str_len = len(edge_str) - return path_str, prior_edge_str_len - - def _update_vertex_degree_list( - self, vertex_degree_list: List[Set[str]], nodes_with_degree: List[str] - ) -> None: - for depth, node_str in enumerate(nodes_with_degree): - if depth >= len(vertex_degree_list): - vertex_degree_list.append(set()) - vertex_degree_list[depth].add(node_str) - - def _extract_labels_from_schema(self) -> Tuple[List[str], List[str]]: - schema = self._get_graph_schema() - vertex_props_str, edge_props_str = schema.split("\n")[:2] - # TODO: rename to vertex (also need update in the schema) - vertex_props_str = vertex_props_str[len("Vertex properties: ") :].strip("[").strip("]") - edge_props_str = edge_props_str[len("Edge properties: ") :].strip("[").strip("]") - vertex_labels = self._extract_label_names(vertex_props_str) - edge_labels = self._extract_label_names(edge_props_str) - return vertex_labels, edge_labels - - @staticmethod - def _extract_label_names(source: str, head: str = "name: ", tail: str = ", ") -> List[str]: - result = [] - for s in source.split(head): - end = s.find(tail) - label = s[:end] - if label: - result.append(label) - return result - - def _get_graph_schema(self, refresh: bool = False) -> str: - if self._schema and not refresh: - return self._schema - - schema = self._client.schema() - vertex_schema = schema.getVertexLabels() - edge_schema = schema.getEdgeLabels() - relationships = schema.getRelations() - - self._schema = ( - f"Vertex properties: {vertex_schema}\n" - f"Edge properties: {edge_schema}\n" - f"Relationships: {relationships}\n" - ) - log.debug("Link(Relation): %s", relationships) - return self._schema - - def _limit_property_query(self, value: Optional[str], item_type: 
str) -> Optional[str]: - # NOTE: we skip the filter for list/set type (e.g., list of string, add it if needed) - if not self._limit_property or not isinstance(value, str): - return value - - max_len = self._max_v_prop_len if item_type == "v" else self._max_e_prop_len - return value[:max_len] if value else value diff --git a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py index 5689a59ac..2ed4e840a 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/index_op/build_semantic_index.py @@ -37,11 +37,15 @@ def __init__(self, embedding: BaseEmbedding): self.folder_name = get_index_folder_name( huge_settings.graph_name, huge_settings.graph_space ) - self.index_dir = str(os.path.join(resource_path, self.folder_name, "graph_vids")) + self.index_dir = str( + os.path.join(resource_path, self.folder_name, "graph_vids") + ) self.filename_prefix = get_filename_prefix( llm_settings.embedding_type, getattr(embedding, "model_name", None) ) - self.vid_index = VectorIndex.from_index_file(self.index_dir, self.filename_prefix) + self.vid_index = VectorIndex.from_index_file( + self.index_dir, self.filename_prefix + ) self.embedding = embedding self.sm = SchemaManager(huge_settings.graph_name) @@ -50,19 +54,27 @@ def _extract_names(self, vertices: list[str]) -> list[str]: def run(self, context: Dict[str, Any]) -> Dict[str, Any]: vertexlabels = self.sm.schema.getSchema()["vertexlabels"] - all_pk_flag = all(data.get("id_strategy") == "PRIMARY_KEY" for data in vertexlabels) + all_pk_flag = bool(vertexlabels) and all( + data.get("id_strategy") == "PRIMARY_KEY" for data in vertexlabels + ) past_vids = self.vid_index.properties # TODO: We should build vid vector index separately, especially when the vertices may be very large - present_vids = context["vertices"] # Warning: data truncated by fetch_graph_data.py 
+ present_vids = context[ + "vertices" + ] # Warning: data truncated by fetch_graph_data.py removed_vids = set(past_vids) - set(present_vids) removed_num = self.vid_index.remove(removed_vids) added_vids = list(set(present_vids) - set(past_vids)) if added_vids: - vids_to_process = self._extract_names(added_vids) if all_pk_flag else added_vids - added_embeddings = asyncio.run(get_embeddings_parallel(self.embedding, vids_to_process)) + vids_to_process = ( + self._extract_names(added_vids) if all_pk_flag else added_vids + ) + added_embeddings = asyncio.run( + get_embeddings_parallel(self.embedding, vids_to_process) + ) log.info("Building vector index for %s vertices...", len(added_vids)) self.vid_index.add(added_embeddings, added_vids) self.vid_index.to_index_file(self.index_dir, self.filename_prefix) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/kg_construction_task.py b/hugegraph-llm/src/hugegraph_llm/operators/kg_construction_task.py deleted file mode 100644 index 3b5c63103..000000000 --- a/hugegraph-llm/src/hugegraph_llm/operators/kg_construction_task.py +++ /dev/null @@ -1,120 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- - -from typing import Dict, Any, Optional, Literal, Union, List - -from hugegraph_llm.models.embeddings.base import BaseEmbedding -from hugegraph_llm.models.llms.base import BaseLLM -from hugegraph_llm.operators.common_op.check_schema import CheckSchema -from hugegraph_llm.operators.common_op.print_result import PrintResult -from hugegraph_llm.operators.document_op.chunk_split import ChunkSplit -from hugegraph_llm.operators.hugegraph_op.commit_to_hugegraph import Commit2Graph -from hugegraph_llm.operators.hugegraph_op.fetch_graph_data import FetchGraphData -from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager -from hugegraph_llm.operators.index_op.build_semantic_index import BuildSemanticIndex -from hugegraph_llm.operators.index_op.build_vector_index import BuildVectorIndex -from hugegraph_llm.operators.llm_op.disambiguate_data import DisambiguateData -from hugegraph_llm.operators.llm_op.info_extract import InfoExtract -from hugegraph_llm.operators.llm_op.property_graph_extract import PropertyGraphExtract -from hugegraph_llm.operators.llm_op.schema_build import SchemaBuilder -from hugegraph_llm.utils.decorators import log_time, log_operator_time, record_rpm -from pyhugegraph.client import PyHugeClient - - -class KgBuilder: - def __init__( - self, - llm: BaseLLM, - embedding: Optional[BaseEmbedding] = None, - graph: Optional[PyHugeClient] = None, - ): - self.operators = [] - self.llm = llm - self.embedding = embedding - self.graph = graph - self.result = None - - def import_schema(self, from_hugegraph=None, from_extraction=None, from_user_defined=None): - if from_hugegraph: - self.operators.append(SchemaManager(from_hugegraph)) - elif from_user_defined: - self.operators.append(CheckSchema(from_user_defined)) - elif from_extraction: - raise NotImplementedError("Not implemented yet") - else: - raise ValueError("No input data / invalid schema type") - return self - - def fetch_graph_data(self): - 
self.operators.append(FetchGraphData(self.graph)) - return self - - def chunk_split( - self, - text: Union[str, List[str]], # text to be split - split_type: Literal["document", "paragraph", "sentence"] = "document", - language: Literal["zh", "en"] = "zh", - ): - self.operators.append(ChunkSplit(text, split_type, language)) - return self - - def extract_info( - self, - example_prompt: Optional[str] = None, - extract_type: Literal["triples", "property_graph"] = "triples", - ): - if extract_type == "triples": - self.operators.append(InfoExtract(self.llm, example_prompt)) - elif extract_type == "property_graph": - self.operators.append(PropertyGraphExtract(self.llm, example_prompt)) - return self - - def disambiguate_word_sense(self): - self.operators.append(DisambiguateData(self.llm)) - return self - - def commit_to_hugegraph(self): - self.operators.append(Commit2Graph()) - return self - - def build_vertex_id_semantic_index(self): - self.operators.append(BuildSemanticIndex(self.embedding)) - return self - - def build_vector_index(self): - self.operators.append(BuildVectorIndex(self.embedding)) - return self - - def print_result(self): - self.operators.append(PrintResult()) - return self - - def build_schema(self): - self.operators.append(SchemaBuilder(self.llm)) - return self - - @log_time("total time") - @record_rpm - def run(self, context: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - for operator in self.operators: - context = self._run_operator(operator, context) - return context - - @log_operator_time - def _run_operator(self, operator, context) -> Dict[str, Any]: - return operator.run(context) diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py index 32ed9651e..48369b4ec 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py @@ -22,7 +22,9 @@ from 
hugegraph_llm.config import prompt, llm_settings from hugegraph_llm.models.llms.base import BaseLLM from hugegraph_llm.models.llms.init_llm import LLMs -from hugegraph_llm.operators.document_op.textrank_word_extract import MultiLingualTextRank +from hugegraph_llm.operators.document_op.textrank_word_extract import ( + MultiLingualTextRank, +) from hugegraph_llm.utils.log import log KEYWORDS_EXTRACT_TPL = prompt.keywords_extract_prompt @@ -43,8 +45,8 @@ def __init__( self._extract_template = extract_template or KEYWORDS_EXTRACT_TPL self._extract_method = llm_settings.keyword_extract_type.lower() self._textrank_model = MultiLingualTextRank( - keyword_num=max_keywords, - window_size=llm_settings.window_size) + keyword_num=max_keywords, window_size=llm_settings.window_size + ) def run(self, context: Dict[str, Any]) -> Dict[str, Any]: if self._query is None: @@ -66,7 +68,11 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: max_keyword_num = self._max_keywords self._max_keywords = max(1, max_keyword_num) - method = (context.get("extract_method", self._extract_method) or "LLM").strip().lower() + method = ( + (context.get("extract_method", self._extract_method) or "LLM") + .strip() + .lower() + ) if method == "llm": # LLM method ranks = self._extract_with_llm() @@ -82,7 +88,7 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: keywords = [] if not ranks else sorted(ranks, key=ranks.get, reverse=True) keywords = [k.replace("'", "") for k in keywords] - context["keywords"] = keywords[:self._max_keywords] + context["keywords"] = keywords[: self._max_keywords] log.info("User Query: %s\nKeywords: %s", self._query, context["keywords"]) # extracting keywords & expanding synonyms increase the call count by 1 @@ -101,7 +107,7 @@ def _extract_with_llm(self) -> Dict[str, float]: return keywords def _extract_with_textrank(self) -> Dict[str, float]: - """ TextRank mode extraction """ + """TextRank mode extraction""" start_time = time.perf_counter() ranks = {} try: @@ 
-111,12 +117,13 @@ def _extract_with_textrank(self) -> Dict[str, float]: except MemoryError as e: log.critical("TextRank memory error (text too large?): %s", e) end_time = time.perf_counter() - log.debug("TextRank Keyword extraction time: %.2f seconds", - end_time - start_time) + log.debug( + "TextRank Keyword extraction time: %.2f seconds", end_time - start_time + ) return ranks def _extract_with_hybrid(self) -> Dict[str, float]: - """ Hybrid mode extraction """ + """Hybrid mode extraction""" ranks = {} if isinstance(llm_settings.hybrid_llm_weights, float): @@ -140,7 +147,7 @@ def _extract_with_hybrid(self) -> Dict[str, float]: if word in llm_scores: ranks[word] += llm_scores[word] * llm_weights if word in tr_scores: - ranks[word] += tr_scores[word] * (1-llm_weights) + ranks[word] += tr_scores[word] * (1 - llm_weights) end_time = time.perf_counter() log.debug("Hybrid Keyword extraction time: %.2f seconds", end_time - start_time) @@ -151,13 +158,11 @@ def _extract_keywords_from_response( response: str, lowercase: bool = True, start_token: str = "", -<<<<<<< HEAD ) -> Dict[str, float]: - results = {} # use re.escape(start_token) if start_token contains special chars like */&/^ etc. - matches = re.findall(rf'{start_token}([^\n]+\n?)', response) + matches = re.findall(rf"{start_token}([^\n]+\n?)", response) for match in matches: match = match.strip() @@ -175,34 +180,13 @@ def _extract_keywords_from_response( continue score_val = float(score_raw) if not 0.0 <= score_val <= 1.0: - log.warning("Score out of range for %s: %s", word_raw, score_val) + log.warning( + "Score out of range for %s: %s", word_raw, score_val + ) score_val = min(1.0, max(0.0, score_val)) word_out = word_raw.lower() if lowercase else word_raw results[word_out] = score_val except (ValueError, AttributeError) as e: log.warning("Failed to parse item '%s': %s", item, e) continue -======= - ) -> Set[str]: - keywords = [] - # use re.escape(start_token) if start_token contains special chars like */&/^ etc. 
- matches = re.findall(rf"{start_token}[^\n]+\n?", response) - - for match in matches: - match = match[len(start_token) :].strip() - keywords.extend( - k.lower() if lowercase else k - for k in re.split(r"[,,]+", match) - if len(k.strip()) > 1 - ) - - # if the keyword consists of multiple words, split into sub-words (removing stopwords) - results = set(keywords) - for token in keywords: - sub_tokens = re.findall(r"\w+", token) - if len(sub_tokens) > 1: - results.update( - w for w in sub_tokens if w not in NLTKHelper().stopwords(lang=self._language) - ) ->>>>>>> 78011d3 (Refactor: text2germlin with PCgraph framework (#50)) return results diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py b/hugegraph-llm/src/hugegraph_llm/operators/operator_list.py similarity index 54% rename from hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py rename to hugegraph-llm/src/hugegraph_llm/operators/operator_list.py index 58848f827..6b6bf48e2 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/operator_list.py @@ -14,45 +14,137 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+from typing import Optional, List, Literal, Union - -from typing import Any, Dict, List, Literal, Optional - -from hugegraph_llm.config import huge_settings, prompt +from hugegraph_llm.config import huge_settings from hugegraph_llm.models.embeddings.base import BaseEmbedding -from hugegraph_llm.models.embeddings.init_embedding import Embeddings from hugegraph_llm.models.llms.base import BaseLLM -from hugegraph_llm.models.llms.init_llm import LLMs -from hugegraph_llm.operators.common_op.merge_dedup_rerank import MergeDedupRerank +from hugegraph_llm.operators.common_op.check_schema import CheckSchema from hugegraph_llm.operators.common_op.print_result import PrintResult -from hugegraph_llm.operators.document_op.word_extract import WordExtract -from hugegraph_llm.operators.hugegraph_op.graph_rag_query import GraphRAGQuery from hugegraph_llm.operators.hugegraph_op.schema_manager import SchemaManager +from hugegraph_llm.operators.index_op.build_gremlin_example_index import ( + BuildGremlinExampleIndex, +) +from hugegraph_llm.operators.index_op.gremlin_example_index_query import ( + GremlinExampleIndexQuery, +) +from hugegraph_llm.operators.llm_op.gremlin_generate import GremlinGenerateSynthesize +from hugegraph_llm.utils.decorators import log_time, log_operator_time, record_rpm +from hugegraph_llm.operators.hugegraph_op.fetch_graph_data import FetchGraphData +from hugegraph_llm.operators.document_op.chunk_split import ChunkSplit +from hugegraph_llm.operators.llm_op.info_extract import InfoExtract +from hugegraph_llm.operators.llm_op.property_graph_extract import PropertyGraphExtract +from hugegraph_llm.operators.llm_op.disambiguate_data import DisambiguateData +from hugegraph_llm.operators.hugegraph_op.commit_to_hugegraph import Commit2Graph +from hugegraph_llm.operators.index_op.build_semantic_index import BuildSemanticIndex +from hugegraph_llm.operators.index_op.build_vector_index import BuildVectorIndex +from hugegraph_llm.operators.document_op.word_extract import 
WordExtract +from hugegraph_llm.operators.llm_op.keyword_extract import KeywordExtract from hugegraph_llm.operators.index_op.semantic_id_query import SemanticIdQuery from hugegraph_llm.operators.index_op.vector_index_query import VectorIndexQuery +from hugegraph_llm.operators.common_op.merge_dedup_rerank import MergeDedupRerank from hugegraph_llm.operators.llm_op.answer_synthesize import AnswerSynthesize -from hugegraph_llm.operators.llm_op.keyword_extract import KeywordExtract -from hugegraph_llm.utils.decorators import log_operator_time, log_time, record_rpm +from pyhugegraph.client import PyHugeClient + +class OperatorList: + def __init__( + self, + llm: BaseLLM, + embedding: BaseEmbedding, + graph: Optional[PyHugeClient] = None, + ): + self.llm = llm + self.embedding = embedding + self.result = None + self.operators = [] + self.graph = graph -class RAGPipeline: - """ - RAGPipeline is a (core)class that encapsulates a series of operations for extracting information from text, - querying graph databases and vector indices, merging and re-ranking results, and generating answers. - """ + def clear(self): + self.operators = [] + return self - def __init__(self, llm: Optional[BaseLLM] = None, embedding: Optional[BaseEmbedding] = None): - """ - Initialize the RAGPipeline with optional LLM and embedding models. + def example_index_build(self, examples): + self.operators.append(BuildGremlinExampleIndex(self.embedding, examples)) + return self - :param llm: Optional LLM model to use. - :param embedding: Optional embedding model to use. 
- """ - self._chat_llm = llm or LLMs().get_chat_llm() - self._extract_llm = llm or LLMs().get_extract_llm() - self._text2gqlt_llm = llm or LLMs().get_text2gql_llm() - self._embedding = embedding or Embeddings().get_embedding() - self._operators: List[Any] = [] + def import_schema( + self, from_hugegraph=None, from_extraction=None, from_user_defined=None + ): + if from_hugegraph: + self.operators.append(SchemaManager(from_hugegraph)) + elif from_user_defined: + self.operators.append(CheckSchema(from_user_defined)) + elif from_extraction: + raise NotImplementedError("Not implemented yet") + else: + raise ValueError("No input data / invalid schema type") + return self + + def example_index_query(self, num_examples): + self.operators.append(GremlinExampleIndexQuery(self.embedding, num_examples)) + return self + + def gremlin_generate_synthesize( + self, + schema, + gremlin_prompt: Optional[str] = None, + vertices: Optional[List[str]] = None, + ): + self.operators.append( + GremlinGenerateSynthesize(self.llm, schema, vertices, gremlin_prompt) + ) + return self + + def print_result(self): + self.operators.append(PrintResult()) + return self + + def fetch_graph_data(self): + if self.graph is None: + raise ValueError("graph client is required for fetch_graph_data operation") + self.operators.append(FetchGraphData(self.graph)) + return self + + def chunk_split( + self, + text: Union[str, List[str]], # text to be split + split_type: Literal["document", "paragraph", "sentence"] = "document", + language: Literal["zh", "en"] = "zh", + ): + self.operators.append(ChunkSplit(text, split_type, language)) + return self + + def extract_info( + self, + example_prompt: Optional[str] = None, + extract_type: Literal["triples", "property_graph"] = "triples", + ): + if extract_type == "triples": + self.operators.append(InfoExtract(self.llm, example_prompt)) + elif extract_type == "property_graph": + self.operators.append(PropertyGraphExtract(self.llm, example_prompt)) + else: + raise 
ValueError( + f"invalid extract_type: {extract_type!r}, expected 'triples' or 'property_graph'" + ) + return self + + def disambiguate_word_sense(self): + self.operators.append(DisambiguateData(self.llm)) + return self + + def commit_to_hugegraph(self): + self.operators.append(Commit2Graph()) + return self + + def build_vertex_id_semantic_index(self): + self.operators.append(BuildSemanticIndex(self.embedding)) + return self + + def build_vector_index(self): + self.operators.append(BuildVectorIndex(self.embedding)) + return self def extract_word(self, text: Optional[str] = None): """ @@ -61,7 +153,7 @@ def extract_word(self, text: Optional[str] = None): :param text: Text to extract words from. :return: Self-instance for chaining. """ - self._operators.append(WordExtract(text=text)) + self.operators.append(WordExtract(text=text)) return self def extract_keywords( @@ -76,18 +168,11 @@ def extract_keywords( :param extract_template: Template for keyword extraction. :return: Self-instance for chaining. """ - self._operators.append( - KeywordExtract( - text=text, - extract_template=extract_template - ) + self.operators.append( + KeywordExtract(text=text, extract_template=extract_template) ) return self - def import_schema(self, graph_name: str): - self._operators.append(SchemaManager(graph_name)) - return self - def keywords_to_vid( self, by: Literal["query", "keywords"] = "keywords", @@ -103,9 +188,9 @@ def keywords_to_vid( :param vector_dis_threshold: Vector distance threshold. :return: Self-instance for chaining. 
""" - self._operators.append( + self.operators.append( SemanticIdQuery( - embedding=self._embedding, + embedding=self.embedding, by=by, topk_per_keyword=topk_per_keyword, topk_per_query=topk_per_query, @@ -114,41 +199,6 @@ def keywords_to_vid( ) return self - def query_graphdb( - self, - max_deep: int = 2, - max_graph_items: int = huge_settings.max_graph_items, - max_v_prop_len: int = 2048, - max_e_prop_len: int = 256, - prop_to_match: Optional[str] = None, - num_gremlin_generate_example: Optional[int] = -1, - gremlin_prompt: Optional[str] = prompt.gremlin_generate_prompt, - ): - """ - Add a graph RAG query operator to the pipeline. - - :param max_deep: Maximum depth for the graph query. - :param max_graph_items: Maximum number of items to retrieve. - :param max_v_prop_len: Maximum length of vertex properties. - :param max_e_prop_len: Maximum length of edge properties. - :param prop_to_match: Property to match in the graph. - :param num_gremlin_generate_example: Number of examples to generate. - :param gremlin_prompt: Gremlin prompt for generating examples. - :return: Self-instance for chaining. - """ - self._operators.append( - GraphRAGQuery( - max_deep=max_deep, - max_graph_items=max_graph_items, - max_v_prop_len=max_v_prop_len, - max_e_prop_len=max_e_prop_len, - prop_to_match=prop_to_match, - num_gremlin_generate_example=num_gremlin_generate_example, - gremlin_prompt=gremlin_prompt, - ) - ) - return self - def query_vector_index(self, max_items: int = 3): """ Add a vector index query operator to the pipeline. @@ -156,9 +206,9 @@ def query_vector_index(self, max_items: int = 3): :param max_items: Maximum number of items to retrieve. :return: Self-instance for chaining. """ - self._operators.append( + self.operators.append( VectorIndexQuery( - embedding=self._embedding, + embedding=self.embedding, topk=max_items, ) ) @@ -177,9 +227,9 @@ def merge_dedup_rerank( :return: Self-instance for chaining. 
""" - self._operators.append( + self.operators.append( MergeDedupRerank( - embedding=self._embedding, + embedding=self.embedding, graph_ratio=graph_ratio, method=rerank_method, near_neighbor_first=near_neighbor_first, @@ -207,7 +257,7 @@ def synthesize_answer( :param answer_prompt: Template for the answer synthesis prompt. :return: Self-instance for chaining. """ - self._operators.append( + self.operators.append( AnswerSynthesize( raw_answer=raw_answer, vector_only_answer=vector_only_answer, @@ -218,32 +268,11 @@ def synthesize_answer( ) return self - def print_result(self): - """ - Add a print result operator to the pipeline. - - :return: Self-instance for chaining. - """ - self._operators.append(PrintResult()) - return self - @log_time("total time") @record_rpm - def run(self, **kwargs) -> Dict[str, Any]: - """ - Execute all operators in the pipeline in sequence. - - :param kwargs: Additional context to pass to operators. - :return: Final context after all operators have been executed. - """ - if len(self._operators) == 0: - self.extract_keywords().query_graphdb( - max_graph_items=kwargs.get("max_graph_items") - ).synthesize_answer() - + def run(self, **kwargs): context = kwargs - - for operator in self._operators: + for operator in self.operators: context = self._run_operator(operator, context) return context diff --git a/hugegraph-llm/src/hugegraph_llm/operators/util.py b/hugegraph-llm/src/hugegraph_llm/operators/util.py deleted file mode 100644 index 60bdc2e86..000000000 --- a/hugegraph-llm/src/hugegraph_llm/operators/util.py +++ /dev/null @@ -1,27 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from PyCGraph import CStatus - - -def init_context(obj) -> CStatus: - try: - obj.context = obj.getGParamWithNoEmpty("wkflow_state") - obj.wk_input = obj.getGParamWithNoEmpty("wkflow_input") - if obj.context is None or obj.wk_input is None: - return CStatus(-1, "Required workflow parameters not found") - return CStatus() - except Exception as e: - return CStatus(-1, f"Failed to initialize context: {str(e)}") diff --git a/hugegraph-llm/src/hugegraph_llm/state/ai_state.py b/hugegraph-llm/src/hugegraph_llm/state/ai_state.py index 3a6fd3c1c..429aba955 100644 --- a/hugegraph-llm/src/hugegraph_llm/state/ai_state.py +++ b/hugegraph-llm/src/hugegraph_llm/state/ai_state.py @@ -13,64 +13,71 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import AsyncGenerator, Union, List, Optional, Any, Dict from PyCGraph import GParam, CStatus -from typing import Union, List, Optional, Any +from hugegraph_llm.utils.log import log class WkFlowInput(GParam): - texts: Union[str, List[str]] = None # texts input used by ChunkSplit Node - language: str = None # language configuration used by ChunkSplit Node - split_type: str = None # split type used by ChunkSplit Node - example_prompt: str = None # need by graph information extract - schema: str = None # Schema information requeired by SchemaNode - data_json = None - extract_type = None - query_examples = None - few_shot_schema = None + texts: Optional[Union[str, List[str]]] = None # texts input used by ChunkSplit Node + language: Optional[str] = None # language configuration used by ChunkSplit Node + split_type: Optional[str] = None # split type used by ChunkSplit Node + example_prompt: Optional[str] = None # need by graph information extract + schema: Optional[str] = None # Schema information requeired by SchemaNode + data_json: Optional[Dict[str, Any]] = None + extract_type: Optional[str] = None + query_examples: Optional[Any] = None + few_shot_schema: Optional[Any] = None # Fields related to PromptGenerate - source_text: str = None # Original text - scenario: str = None # Scenario description - example_name: str = None # Example name + source_text: Optional[str] = None # Original text + scenario: Optional[str] = None # Scenario description + example_name: Optional[str] = None # Example name # Fields for Text2Gremlin - example_num: int = None - gremlin_prompt: str = None + example_num: Optional[int] = None requested_outputs: Optional[List[str]] = None # RAG Flow related fields - query: str = None # User query for RAG - vector_search: bool = None # Enable vector search - graph_search: bool = None # Enable graph search - raw_answer: bool = None # Return raw answer - vector_only_answer: bool = None # Vector only answer mode - graph_only_answer: bool = None 
# Graph only answer mode - graph_vector_answer: bool = None # Combined graph and vector answer - graph_ratio: float = None # Graph ratio for merging - rerank_method: str = None # Reranking method - near_neighbor_first: bool = None # Near neighbor first flag - custom_related_information: str = None # Custom related information - answer_prompt: str = None # Answer generation prompt - keywords_extract_prompt: str = None # Keywords extraction prompt - gremlin_tmpl_num: int = None # Gremlin template number - gremlin_prompt: str = None # Gremlin generation prompt - max_graph_items: int = None # Maximum graph items - topk_return_results: int = None # Top-k return results - vector_dis_threshold: float = None # Vector distance threshold - topk_per_keyword: int = None # Top-k per keyword - max_keywords: int = None - max_items: int = None + query: Optional[str] = None # User query for RAG + vector_search: Optional[bool] = None # Enable vector search + graph_search: Optional[bool] = None # Enable graph search + raw_answer: Optional[bool] = None # Return raw answer + vector_only_answer: Optional[bool] = None # Vector only answer mode + graph_only_answer: Optional[bool] = None # Graph only answer mode + graph_vector_answer: Optional[bool] = None # Combined graph and vector answer + graph_ratio: Optional[float] = None # Graph ratio for merging + rerank_method: Optional[str] = None # Reranking method + near_neighbor_first: Optional[bool] = None # Near neighbor first flag + custom_related_information: Optional[str] = None # Custom related information + answer_prompt: Optional[str] = None # Answer generation prompt + keywords_extract_prompt: Optional[str] = None # Keywords extraction prompt + gremlin_tmpl_num: Optional[int] = None # Gremlin template number + gremlin_prompt: Optional[str] = None # Gremlin generation prompt + max_graph_items: Optional[int] = None # Maximum graph items + topk_return_results: Optional[int] = None # Top-k return results + vector_dis_threshold: 
Optional[float] = None # Vector distance threshold + topk_per_keyword: Optional[int] = None # Top-k per keyword + max_keywords: Optional[int] = None + max_items: Optional[int] = None # Semantic query related fields - semantic_by: str = None # Semantic query method - topk_per_query: int = None # Top-k per query + semantic_by: Optional[str] = None # Semantic query method + topk_per_query: Optional[int] = None # Top-k per query # Graph query related fields - max_deep: int = None # Maximum depth for graph traversal - max_v_prop_len: int = None # Maximum vertex property length - max_e_prop_len: int = None # Maximum edge property length - prop_to_match: str = None # Property to match + max_deep: Optional[int] = None # Maximum depth for graph traversal + max_v_prop_len: Optional[int] = None # Maximum vertex property length + max_e_prop_len: Optional[int] = None # Maximum edge property length + prop_to_match: Optional[str] = None # Property to match - stream: bool = None # used for recognize stream mode + stream: Optional[bool] = None # used for recognize stream mode + + # used for rag_recall api + is_graph_rag_recall: bool = False + is_vector_only: bool = False + + # used for build text2gremin index + examples: Optional[List[Dict[str, str]]] = None def reset(self, _: CStatus) -> None: self.texts = None @@ -78,7 +85,6 @@ def reset(self, _: CStatus) -> None: self.split_type = None self.example_prompt = None self.schema = None - self.graph_name = None self.data_json = None self.extract_type = None self.query_examples = None @@ -106,7 +112,6 @@ def reset(self, _: CStatus) -> None: self.answer_prompt = None self.keywords_extract_prompt = None self.gremlin_tmpl_num = None - self.gremlin_prompt = None self.max_graph_items = None self.topk_return_results = None self.vector_dis_threshold = None @@ -123,6 +128,10 @@ def reset(self, _: CStatus) -> None: self.prop_to_match = None self.stream = None + self.examples = None + self.is_graph_rag_recall = False + self.is_vector_only = 
False + class WkFlowState(GParam): schema: Optional[str] = None # schema message @@ -134,9 +143,9 @@ class WkFlowState(GParam): call_count: Optional[int] = None keywords: Optional[List[str]] = None - vector_result = None - graph_result = None - keywords_embeddings = None + vector_result: Optional[Any] = None + graph_result: Optional[Any] = None + keywords_embeddings: Optional[Any] = None generated_extract_prompt: Optional[str] = None # Fields for Text2Gremlin results @@ -146,18 +155,43 @@ class WkFlowState(GParam): template_exec_res: Optional[Any] = None raw_exec_res: Optional[Any] = None - match_vids = None - vector_result = None - graph_result = None + match_vids: Optional[Any] = None + + raw_answer: Optional[str] = None + vector_only_answer: Optional[str] = None + graph_only_answer: Optional[str] = None + graph_vector_answer: Optional[str] = None + + merged_result: Optional[Any] = None + + vertex_num: Optional[int] = None + edge_num: Optional[int] = None + note: Optional[str] = None + removed_vid_vector_num: Optional[int] = None + added_vid_vector_num: Optional[int] = None + raw_texts: Optional[List] = None + query_examples: Optional[List] = None + few_shot_schema: Optional[Dict] = None + source_text: Optional[str] = None + scenario: Optional[str] = None + example_name: Optional[str] = None - raw_answer: str = None - vector_only_answer: str = None - graph_only_answer: str = None - graph_vector_answer: str = None + graph_ratio: Optional[float] = None + query: Optional[str] = None + vector_search: Optional[bool] = None + graph_search: Optional[bool] = None + max_graph_items: Optional[int] = None + stream_generator: Optional[AsyncGenerator] = None - merged_result = None + graph_result_flag: Optional[int] = None + vertex_degree_list: Optional[List] = None + knowledge_with_degree: Optional[Dict] = None + graph_context_head: Optional[str] = None - def setup(self): + embed_dim: Optional[int] = None + is_graph_rag_recall: Optional[bool] = None + + def setup(self) -> 
CStatus: self.schema = None self.simple_schema = None self.chunks = None @@ -184,9 +218,36 @@ def setup(self): self.graph_only_answer = None self.graph_vector_answer = None - self.vector_result = None - self.graph_result = None self.merged_result = None + + self.match_vids = None + self.vertex_num = None + self.edge_num = None + self.note = None + self.removed_vid_vector_num = None + self.added_vid_vector_num = None + + self.raw_texts = None + self.query_examples = None + self.few_shot_schema = None + self.source_text = None + self.scenario = None + self.example_name = None + + self.graph_ratio = None + self.query = None + self.vector_search = None + self.graph_search = None + self.max_graph_items = None + + self.stream_generator = None + self.graph_result_flag = None + self.vertex_degree_list = None + self.knowledge_with_degree = None + self.graph_context_head = None + + self.embed_dim = None + self.is_graph_rag_recall = None return CStatus() def to_json(self): @@ -210,4 +271,9 @@ def assign_from_json(self, data_json: dict): Assigns each key in the input json object as a member variable of WkFlowState. """ for k, v in data_json.items(): - setattr(self, k, v) + if hasattr(self, k): + setattr(self, k, v) + else: + log.warning( + "key %s should be a member of WkFlowState & type %s", k, type(v) + ) diff --git a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py index 3f527f2fa..9c53e8183 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py @@ -16,29 +16,28 @@ # under the License. 
-import json import os import traceback -from typing import Dict, Any, Union, Optional +from typing import Dict, Any, Union, List import gradio as gr +from hugegraph_llm.flows import FlowName from hugegraph_llm.flows.scheduler import SchedulerSingleton +from pyhugegraph.client import PyHugeClient from .embedding_utils import get_filename_prefix, get_index_folder_name -from .hugegraph_utils import get_hg_client, clean_hg_data +from .hugegraph_utils import clean_hg_data from .log import log from .vector_index_utils import read_documents from ..config import resource_path, huge_settings, llm_settings from ..indices.vector_index import VectorIndex from ..models.embeddings.init_embedding import Embeddings -from ..models.llms.init_llm import LLMs -from ..operators.kg_construction_task import KgBuilder def get_graph_index_info(): try: scheduler = SchedulerSingleton.get_instance() - return scheduler.schedule_flow("get_graph_index_info") + return scheduler.schedule_flow(FlowName.GET_GRAPH_INDEX_INFO) except Exception as e: # pylint: disable=broad-exception-caught log.error(e) raise gr.Error(str(e)) @@ -63,63 +62,31 @@ def clean_all_graph_index(): gr.Info("Clear graph index and text2gql index successfully!") -def clean_all_graph_data(): - clean_hg_data() - log.warning("Clear graph data successfully!") - gr.Info("Clear graph data successfully!") - - -def parse_schema(schema: str, builder: KgBuilder) -> Optional[str]: - schema = schema.strip() - if schema.startswith("{"): - try: - schema = json.loads(schema) - builder.import_schema(from_user_defined=schema) - except json.JSONDecodeError: - log.error("Invalid JSON format in schema. Please check it again.") - return "ERROR: Invalid JSON format in schema. Please check it carefully." 
+def get_vertex_details( + vertex_ids: List[str], context: Dict[str, Any] +) -> List[Dict[str, Any]]: + if isinstance(context.get("graph_client"), PyHugeClient): + client = context["graph_client"] else: - log.info("Get schema '%s' from graphdb.", schema) - builder.import_schema(from_hugegraph=schema) - return None + url = context.get("url") or "http://localhost:8080" + graph = context.get("graph") or "hugegraph" + user = context.get("user") or "admin" + pwd = context.get("pwd") or "admin" + gs = context.get("graphspace") or None + client = PyHugeClient(url, graph, user, pwd, gs) + if not vertex_ids: + return [] + formatted_ids = ", ".join(f"'{vid}'" for vid in vertex_ids) + gremlin_query = f"g.V({formatted_ids}).limit(20)" + result = client.gremlin().exec(gremlin=gremlin_query)["data"] + return result -def extract_graph_origin(input_file, input_text, schema, example_prompt) -> str: - texts = read_documents(input_file, input_text) - builder = KgBuilder( - LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client() - ) - if not schema: - return "ERROR: please input with correct schema/format." 
- - error_message = parse_schema(schema, builder) - if error_message: - return error_message - builder.chunk_split(texts, "document", "zh").extract_info( - example_prompt, "property_graph" - ) - try: - context = builder.run() - if not context["vertices"] and not context["edges"]: - log.info("Please check the schema.(The schema may not match the Doc)") - return json.dumps( - { - "vertices": context["vertices"], - "edges": context["edges"], - "warning": "The schema may not match the Doc", - }, - ensure_ascii=False, - indent=2, - ) - return json.dumps( - {"vertices": context["vertices"], "edges": context["edges"]}, - ensure_ascii=False, - indent=2, - ) - except Exception as e: # pylint: disable=broad-exception-caught - log.error(e) - raise gr.Error(str(e)) +def clean_all_graph_data(): + clean_hg_data() + log.warning("Clear graph data successfully!") + gr.Info("Clear graph data successfully!") def extract_graph(input_file, input_text, schema, example_prompt) -> str: @@ -130,7 +97,7 @@ def extract_graph(input_file, input_text, schema, example_prompt) -> str: try: return scheduler.schedule_flow( - "graph_extract", schema, texts, example_prompt, "property_graph" + FlowName.GRAPH_EXTRACT, schema, texts, example_prompt, "property_graph" ) except Exception as e: # pylint: disable=broad-exception-caught log.error(e) @@ -140,7 +107,7 @@ def extract_graph(input_file, input_text, schema, example_prompt) -> str: def update_vid_embedding(): scheduler = SchedulerSingleton.get_instance() try: - return scheduler.schedule_flow("update_vid_embeddings") + return scheduler.schedule_flow(FlowName.UPDATE_VID_EMBEDDINGS) except Exception as e: # pylint: disable=broad-exception-caught log.error(e) raise gr.Error(str(e)) @@ -149,7 +116,7 @@ def update_vid_embedding(): def import_graph_data(data: str, schema: str) -> Union[str, Dict[str, Any]]: try: scheduler = SchedulerSingleton.get_instance() - return scheduler.schedule_flow("import_graph_data", data, schema) + return 
scheduler.schedule_flow(FlowName.IMPORT_GRAPH_DATA, data, schema) except Exception as e: # pylint: disable=W0718 log.error(e) traceback.print_exc() @@ -162,7 +129,8 @@ def build_schema(input_text, query_example, few_shot): scheduler = SchedulerSingleton.get_instance() try: return scheduler.schedule_flow( - "build_schema", input_text, query_example, few_shot + FlowName.BUILD_SCHEMA, input_text, query_example, few_shot ) - except (TypeError, ValueError) as e: + except Exception as e: # pylint: disable=broad-exception-caught + log.error("Schema generation failed: %s", e) raise gr.Error(f"Schema generation failed: {e}") diff --git a/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py index 301a6bdab..67904a445 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py @@ -22,6 +22,7 @@ import gradio as gr from hugegraph_llm.config import resource_path, huge_settings, llm_settings +from hugegraph_llm.flows import FlowName from hugegraph_llm.indices.vector_index import VectorIndex from hugegraph_llm.models.embeddings.init_embedding import model_map from hugegraph_llm.flows.scheduler import SchedulerSingleton @@ -50,7 +51,9 @@ def read_documents(input_file, input_text): texts.append(text) elif full_path.endswith(".pdf"): # TODO: support PDF file - raise gr.Error("PDF will be supported later! Try to upload text/docx now") + raise gr.Error( + "PDF will be supported later! 
Try to upload text/docx now" + ) else: raise gr.Error("Please input txt or docx file.") else: @@ -60,7 +63,9 @@ def read_documents(input_file, input_text): # pylint: disable=C0301 def get_vector_index_info(): - folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) + folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) filename_prefix = get_filename_prefix( llm_settings.embedding_type, model_map.get(llm_settings.embedding_type) ) @@ -87,11 +92,15 @@ def get_vector_index_info(): def clean_vector_index(): - folder_name = get_index_folder_name(huge_settings.graph_name, huge_settings.graph_space) + folder_name = get_index_folder_name( + huge_settings.graph_name, huge_settings.graph_space + ) filename_prefix = get_filename_prefix( llm_settings.embedding_type, model_map.get(llm_settings.embedding_type) ) - VectorIndex.clean(str(os.path.join(resource_path, folder_name, "chunks")), filename_prefix) + VectorIndex.clean( + str(os.path.join(resource_path, folder_name, "chunks")), filename_prefix + ) gr.Info("Clean vector index successfully!") @@ -100,4 +109,4 @@ def build_vector_index(input_file, input_text): raise gr.Error("Please only choose one between file and text.") texts = read_documents(input_file, input_text) scheduler = SchedulerSingleton.get_instance() - return scheduler.schedule_flow("build_vector_index", texts) + return scheduler.schedule_flow(FlowName.BUILD_VECTOR_INDEX, texts) diff --git a/hugegraph-ml/pyproject.toml b/hugegraph-ml/pyproject.toml index 6d46ba74c..929eb3aa1 100644 --- a/hugegraph-ml/pyproject.toml +++ b/hugegraph-ml/pyproject.toml @@ -22,7 +22,7 @@ build-backend = "hatchling.build" [project] name = "hugegraph-ml" -version = "1.5.0" +version = "1.7.0" description = "Machine learning extensions for Apache HugeGraph." 
authors = [ { name = "Apache HugeGraph Contributors", email = "dev@hugegraph.apache.org" }, diff --git a/hugegraph-python-client/pyproject.toml b/hugegraph-python-client/pyproject.toml index 81565d9ab..ddae125d8 100644 --- a/hugegraph-python-client/pyproject.toml +++ b/hugegraph-python-client/pyproject.toml @@ -17,7 +17,7 @@ [project] name = "hugegraph-python-client" -version = "1.5.0" +version = "1.7.0" description = "A Python SDK for Apache HugeGraph Database." authors = [ { name = "Apache HugeGraph Contributors", email = "dev@hugegraph.apache.org" }, diff --git a/pyproject.toml b/pyproject.toml index 8bcf58929..2dd4161d4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ [project] name = "hugegraph-ai" -version = "1.5.0" +version = "1.7.0" description = "A repository for AI-related projects for Apache HugeGraph." authors = [ { name = "Apache HugeGraph Contributors", email = "dev@hugegraph.apache.org" }, diff --git a/scripts/build_llm_image.sh b/scripts/build_llm_image.sh old mode 100644 new mode 100755 index 42aa36e39..7425b3df9 --- a/scripts/build_llm_image.sh +++ b/scripts/build_llm_image.sh @@ -18,7 +18,7 @@ set -e -tag="1.5.0" +tag="1.7.0" script_dir=$(realpath "$(dirname "$0")") diff --git a/.vibedev/spec/hugegraph-llm/fixed_flow/design.md b/spec/hugegraph-llm/fixed_flow/design.md similarity index 99% rename from .vibedev/spec/hugegraph-llm/fixed_flow/design.md rename to spec/hugegraph-llm/fixed_flow/design.md index c5777236d..5ad64407a 100644 --- a/.vibedev/spec/hugegraph-llm/fixed_flow/design.md +++ b/spec/hugegraph-llm/fixed_flow/design.md @@ -202,7 +202,7 @@ flowchart TD - `BuildVectorIndexFlow`: 向量索引构建工作流 - `GraphExtractFlow`: 图抽取工作流 - `ImportGraphDataFlow`: 图数据导入工作流 - - `UpdateVidEmbeddingsFlows`: 向量更新工作流 + - `UpdateVidEmbeddingsFlow`: 向量更新工作流 - `GetGraphIndexInfoFlow`: 图索引信息获取工作流 - `BuildSchemaFlow`: 模式构建工作流 - `PromptGenerateFlow`: 提示词生成工作流 @@ -407,7 +407,6 @@ class GraphExtractFlow(BaseFlow): prepared_input.split_type = "document" 
prepared_input.example_prompt = example_prompt prepared_input.schema = schema - prepare_schema(prepared_input, schema) return def build_flow(self, schema, texts, example_prompt, extract_type): diff --git a/.vibedev/spec/hugegraph-llm/fixed_flow/requirements.md b/spec/hugegraph-llm/fixed_flow/requirements.md similarity index 100% rename from .vibedev/spec/hugegraph-llm/fixed_flow/requirements.md rename to spec/hugegraph-llm/fixed_flow/requirements.md diff --git a/.vibedev/spec/hugegraph-llm/fixed_flow/tasks.md b/spec/hugegraph-llm/fixed_flow/tasks.md similarity index 100% rename from .vibedev/spec/hugegraph-llm/fixed_flow/tasks.md rename to spec/hugegraph-llm/fixed_flow/tasks.md diff --git a/style/pylint.conf b/style/pylint.conf index 6ccb7a078..4fb3a17c2 100644 --- a/style/pylint.conf +++ b/style/pylint.conf @@ -476,6 +476,7 @@ disable=raw-checker-failed, # it should appear only once). See also the "--disable" option for examples. enable= +extension-pkg-whitelist=PyCGraph [METHOD_ARGS] @@ -596,7 +597,8 @@ contextmanager-decorators=contextlib.contextmanager # List of members which are set dynamically and missed by pylint inference # system, and so shouldn't trigger E1101 when accessed. Python regular # expressions are accepted. -generated-members= +ignored-modules=PyCGraph +generated-members=PyCGraph.* # Tells whether to warn about missing members when the owner of the attribute # is inferred to be None. 
diff --git a/vermeer-python-client/pyproject.toml b/vermeer-python-client/pyproject.toml index 986010899..d60acc075 100644 --- a/vermeer-python-client/pyproject.toml +++ b/vermeer-python-client/pyproject.toml @@ -17,7 +17,7 @@ [project] name = "vermeer-python-client" -version = "1.5.0" # Independently managed version for the vermeer-python-client package +version = "1.7.0" # Independently managed version for the vermeer-python-client package description = "A Python client library for interacting with Vermeer, a tool for managing and analyzing large-scale graph data." authors = [ { name = "Apache HugeGraph Contributors", email = "dev@hugegraph.apache.org" } @@ -33,7 +33,7 @@ dependencies = [ "setuptools", "urllib3", "rich", - + # Vermeer specific dependencies "python-dateutil", ]