From 090b043d4ef1f37182a34ee8d0915293fdf7084d Mon Sep 17 00:00:00 2001 From: appointat <65004114+Appointat@users.noreply.github.com> Date: Sat, 28 Feb 2026 10:38:47 +0800 Subject: [PATCH 1/4] feat: add CASTS and Light Memory components - Introduced `CastsRestClient` for handling CASTS decision requests. - Added `CastsVector` class for CASTS traversal instructions. - Updated `VectorType` enum to include `CastsVector`. - Created `LightMemConfig` for configuration management of Light Memory. - Implemented `LightMemRestClient` for writing and recalling memory. - Developed `MemoryClient` to facilitate memory operations. - Added request and response classes for memory operations: `MemoryWriteRequest`, `MemoryWriteResponse`, `MemoryRecallRequest`, and `MemoryRecallResponse`. - Introduced `AiEnvelope`, `AiError`, `AiHttpClient`, `AiResponse`, `AiScope`, and `AiTrace` for protocol handling. --- .github/workflows/ci-py311.yml | 122 +++-- geaflow-ai/plugins/casts/.gitignore | 2 +- geaflow-ai/plugins/casts/AGENTS.md | 71 +++ geaflow-ai/plugins/casts/README.md | 130 ++++- geaflow-ai/plugins/casts/api/__init__.py | 18 + geaflow-ai/plugins/casts/api/app.py | 81 ++++ geaflow-ai/plugins/casts/api/envelope.py | 65 +++ geaflow-ai/plugins/casts/api/schema.py | 74 +++ geaflow-ai/plugins/casts/api/service.py | 341 +++++++++++++ geaflow-ai/plugins/casts/core/config.py | 10 +- geaflow-ai/plugins/casts/core/interfaces.py | 5 + geaflow-ai/plugins/casts/core/models.py | 12 +- geaflow-ai/plugins/casts/core/schema.py | 4 +- .../plugins/casts/core/strategy_cache.py | 2 +- .../casts/{data => harness}/__init__.py | 0 .../{simulation => harness/data}/__init__.py | 0 .../{ => harness}/data/graph_generator.py | 16 +- .../{ => harness}/data/real_graph_loader.py | 6 +- .../casts/{ => harness}/data/sources.py | 118 +++-- .../casts/harness/simulation/__init__.py | 17 + .../casts/{ => harness}/simulation/engine.py | 20 +- .../{ => harness}/simulation/evaluator.py | 11 +- .../{ => 
harness}/simulation/executor.py | 7 +- .../casts/{ => harness}/simulation/metrics.py | 47 +- .../casts/{ => harness}/simulation/runner.py | 16 +- .../{ => harness}/simulation/visualizer.py | 14 +- geaflow-ai/plugins/casts/pyproject.toml | 21 +- geaflow-ai/plugins/casts/scripts/smoke.sh | 161 +++++++ .../plugins/casts/services/embedding.py | 14 +- .../plugins/casts/services/llm_oracle.py | 46 +- geaflow-ai/plugins/casts/tests/conftest.py | 6 + geaflow-ai/plugins/casts/tests/test_api.py | 90 ++++ .../casts/tests/test_execution_lifecycle.py | 106 ++--- .../tests/test_gremlin_step_state_machine.py | 70 ++- .../casts/tests/test_lifecycle_integration.py | 262 ++++++---- .../casts/tests/test_metrics_collector.py | 18 +- .../casts/tests/test_signature_abstraction.py | 26 +- .../plugins/casts/tests/test_simple_path.py | 12 +- .../tests/test_starting_node_selection.py | 61 +-- .../casts/tests/test_threshold_calculation.py | 93 ++-- geaflow-ai/plugins/lightmem/.gitignore | 19 + geaflow-ai/plugins/lightmem/api/__init__.py | 18 + geaflow-ai/plugins/lightmem/api/app.py | 137 ++++++ geaflow-ai/plugins/lightmem/api/envelope.py | 67 +++ geaflow-ai/plugins/lightmem/api/py.typed | 2 + geaflow-ai/plugins/lightmem/core/__init__.py | 18 + .../geaflow_ai_lightmem.egg-info/PKG-INFO | 8 + .../geaflow_ai_lightmem.egg-info/SOURCES.txt | 12 + .../dependency_links.txt | 1 + .../geaflow_ai_lightmem.egg-info/requires.txt | 4 + .../top_level.txt | 1 + geaflow-ai/plugins/lightmem/core/ledger.py | 83 ++++ .../plugins/lightmem/core/memory_kernel.py | 119 +++++ geaflow-ai/plugins/lightmem/core/py.typed | 2 + geaflow-ai/plugins/lightmem/core/types.py | 79 +++ geaflow-ai/plugins/lightmem/core/views.py | 108 +++++ geaflow-ai/plugins/lightmem/pyproject.toml | 50 ++ geaflow-ai/plugins/lightmem/scripts/smoke.sh | 157 ++++++ geaflow-ai/plugins/lightmem/tests/test_api.py | 83 ++++ .../tests/test_write_recall_provenance.py | 36 ++ .../geaflow/ai/GeaFlowMemoryServer.java | 120 +++++ 
.../apache/geaflow/ai/GraphMemoryServer.java | 31 +- .../apache/geaflow/ai/casts/CastsConfig.java | 55 +++ .../geaflow/ai/casts/CastsDecisionParser.java | 117 +++++ .../ai/casts/CastsDecisionRequest.java | 67 +++ .../ai/casts/CastsDecisionResponse.java | 37 ++ .../geaflow/ai/casts/CastsOperator.java | 449 ++++++++++++++++++ .../geaflow/ai/casts/CastsRestClient.java | 42 ++ .../geaflow/ai/index/vector/CastsVector.java | 64 +++ .../geaflow/ai/index/vector/VectorType.java | 3 +- .../geaflow/ai/memory/LightMemConfig.java | 59 +++ .../geaflow/ai/memory/LightMemRestClient.java | 49 ++ .../geaflow/ai/memory/MemoryClient.java | 43 ++ .../ai/memory/MemoryRecallRequest.java | 34 ++ .../ai/memory/MemoryRecallResponse.java | 57 +++ .../geaflow/ai/memory/MemoryWriteRequest.java | 47 ++ .../ai/memory/MemoryWriteResponse.java | 45 ++ .../geaflow/ai/protocol/AiEnvelope.java | 46 ++ .../apache/geaflow/ai/protocol/AiError.java | 26 + .../geaflow/ai/protocol/AiHttpClient.java | 95 ++++ .../geaflow/ai/protocol/AiResponse.java | 42 ++ .../apache/geaflow/ai/protocol/AiScope.java | 47 ++ .../apache/geaflow/ai/protocol/AiTrace.java | 44 ++ 83 files changed, 4274 insertions(+), 544 deletions(-) create mode 100644 geaflow-ai/plugins/casts/AGENTS.md create mode 100644 geaflow-ai/plugins/casts/api/__init__.py create mode 100644 geaflow-ai/plugins/casts/api/app.py create mode 100644 geaflow-ai/plugins/casts/api/envelope.py create mode 100644 geaflow-ai/plugins/casts/api/schema.py create mode 100644 geaflow-ai/plugins/casts/api/service.py rename geaflow-ai/plugins/casts/{data => harness}/__init__.py (100%) rename geaflow-ai/plugins/casts/{simulation => harness/data}/__init__.py (100%) rename geaflow-ai/plugins/casts/{ => harness}/data/graph_generator.py (94%) rename geaflow-ai/plugins/casts/{ => harness}/data/real_graph_loader.py (97%) rename geaflow-ai/plugins/casts/{ => harness}/data/sources.py (92%) create mode 100644 geaflow-ai/plugins/casts/harness/simulation/__init__.py rename 
geaflow-ai/plugins/casts/{ => harness}/simulation/engine.py (97%) rename geaflow-ai/plugins/casts/{ => harness}/simulation/evaluator.py (98%) rename geaflow-ai/plugins/casts/{ => harness}/simulation/executor.py (98%) rename geaflow-ai/plugins/casts/{ => harness}/simulation/metrics.py (91%) rename geaflow-ai/plugins/casts/{ => harness}/simulation/runner.py (91%) rename geaflow-ai/plugins/casts/{ => harness}/simulation/visualizer.py (97%) create mode 100755 geaflow-ai/plugins/casts/scripts/smoke.sh create mode 100644 geaflow-ai/plugins/casts/tests/test_api.py create mode 100644 geaflow-ai/plugins/lightmem/.gitignore create mode 100644 geaflow-ai/plugins/lightmem/api/__init__.py create mode 100644 geaflow-ai/plugins/lightmem/api/app.py create mode 100644 geaflow-ai/plugins/lightmem/api/envelope.py create mode 100644 geaflow-ai/plugins/lightmem/api/py.typed create mode 100644 geaflow-ai/plugins/lightmem/core/__init__.py create mode 100644 geaflow-ai/plugins/lightmem/core/geaflow_ai_lightmem.egg-info/PKG-INFO create mode 100644 geaflow-ai/plugins/lightmem/core/geaflow_ai_lightmem.egg-info/SOURCES.txt create mode 100644 geaflow-ai/plugins/lightmem/core/geaflow_ai_lightmem.egg-info/dependency_links.txt create mode 100644 geaflow-ai/plugins/lightmem/core/geaflow_ai_lightmem.egg-info/requires.txt create mode 100644 geaflow-ai/plugins/lightmem/core/geaflow_ai_lightmem.egg-info/top_level.txt create mode 100644 geaflow-ai/plugins/lightmem/core/ledger.py create mode 100644 geaflow-ai/plugins/lightmem/core/memory_kernel.py create mode 100644 geaflow-ai/plugins/lightmem/core/py.typed create mode 100644 geaflow-ai/plugins/lightmem/core/types.py create mode 100644 geaflow-ai/plugins/lightmem/core/views.py create mode 100644 geaflow-ai/plugins/lightmem/pyproject.toml create mode 100755 geaflow-ai/plugins/lightmem/scripts/smoke.sh create mode 100644 geaflow-ai/plugins/lightmem/tests/test_api.py create mode 100644 geaflow-ai/plugins/lightmem/tests/test_write_recall_provenance.py 
create mode 100644 geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsConfig.java create mode 100644 geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsDecisionParser.java create mode 100644 geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsDecisionRequest.java create mode 100644 geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsDecisionResponse.java create mode 100644 geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsOperator.java create mode 100644 geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsRestClient.java create mode 100644 geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/CastsVector.java create mode 100644 geaflow-ai/src/main/java/org/apache/geaflow/ai/memory/LightMemConfig.java create mode 100644 geaflow-ai/src/main/java/org/apache/geaflow/ai/memory/LightMemRestClient.java create mode 100644 geaflow-ai/src/main/java/org/apache/geaflow/ai/memory/MemoryClient.java create mode 100644 geaflow-ai/src/main/java/org/apache/geaflow/ai/memory/MemoryRecallRequest.java create mode 100644 geaflow-ai/src/main/java/org/apache/geaflow/ai/memory/MemoryRecallResponse.java create mode 100644 geaflow-ai/src/main/java/org/apache/geaflow/ai/memory/MemoryWriteRequest.java create mode 100644 geaflow-ai/src/main/java/org/apache/geaflow/ai/memory/MemoryWriteResponse.java create mode 100644 geaflow-ai/src/main/java/org/apache/geaflow/ai/protocol/AiEnvelope.java create mode 100644 geaflow-ai/src/main/java/org/apache/geaflow/ai/protocol/AiError.java create mode 100644 geaflow-ai/src/main/java/org/apache/geaflow/ai/protocol/AiHttpClient.java create mode 100644 geaflow-ai/src/main/java/org/apache/geaflow/ai/protocol/AiResponse.java create mode 100644 geaflow-ai/src/main/java/org/apache/geaflow/ai/protocol/AiScope.java create mode 100644 geaflow-ai/src/main/java/org/apache/geaflow/ai/protocol/AiTrace.java diff --git a/.github/workflows/ci-py311.yml b/.github/workflows/ci-py311.yml index 6ae99c64d..bece3d264 100644 --- 
a/.github/workflows/ci-py311.yml +++ b/.github/workflows/ci-py311.yml @@ -1,29 +1,93 @@ - name: CASTS Python Tests - - on: - pull_request: - paths: - - "geaflow-ai/plugins/casts/**" - push: - branches: ["master"] - paths: - - "geaflow-ai/plugins/casts/**" - workflow_dispatch: - - jobs: - tests: - runs-on: ubuntu-latest - defaults: - run: - working-directory: geaflow-ai/plugins/casts - steps: - - uses: actions/checkout@v4 - - uses: actions/setup-python@v5 - with: - python-version: "3.11" - - name: Install uv - run: pip install uv - - name: Sync deps - run: uv sync - - name: Run tests - run: uv run pytest +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+################################################################################ + +name: Python CI (uv + 3.11) + +on: + pull_request: + paths: + - "geaflow-ai/plugins/casts/**" + - "geaflow-ai/plugins/lightmem/**" + - ".github/workflows/ci-py311.yml" + push: + branches: ["master"] + paths: + - "geaflow-ai/plugins/casts/**" + - "geaflow-ai/plugins/lightmem/**" + - ".github/workflows/ci-py311.yml" + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.event.number || github.run_id }} + cancel-in-progress: true + +jobs: + tests: + name: ${{ matrix.plugin }} (py${{ matrix.python-version }}) + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.11"] + plugin: ["casts", "lightmem"] + + defaults: + run: + shell: bash + working-directory: geaflow-ai/plugins/${{ matrix.plugin }} + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Setup uv + uses: astral-sh/setup-uv@v4 + with: + version: "0.9.17" + enable-cache: true + + - name: Sync deps + run: | + python -m venv .venv + uv sync --extra dev --no-editable + + - name: Lint (ruff) + run: | + uv run --no-editable ruff check . + + - name: Format check (ruff) + run: | + uv run --no-editable ruff format --check . 
+ + - name: Type check (mypy) + run: | + if [[ "${{ matrix.plugin }}" == "casts" ]]; then + uv run --no-editable mypy -p api -p core -p services -p harness + else + uv run --no-editable mypy -p lightmem -p api + fi + + - name: Run tests + run: | + uv run --no-editable pytest -q diff --git a/geaflow-ai/plugins/casts/.gitignore b/geaflow-ai/plugins/casts/.gitignore index 570acfcd7..fd0bedf4b 100644 --- a/geaflow-ai/plugins/casts/.gitignore +++ b/geaflow-ai/plugins/casts/.gitignore @@ -18,5 +18,5 @@ uv.lock .DS_Store # Data files -data/real_graph_data/ +harness/data/real_graph_data/ casts_traversal_path_req_*.png diff --git a/geaflow-ai/plugins/casts/AGENTS.md b/geaflow-ai/plugins/casts/AGENTS.md new file mode 100644 index 000000000..9de3b8898 --- /dev/null +++ b/geaflow-ai/plugins/casts/AGENTS.md @@ -0,0 +1,71 @@ +# CASTS Agent Instructions (geaflow-ai/plugins/casts) + +This file defines CASTS plugin-local instructions for coding agents. + +## Must-Read (Before You Change Code) + +- Review and follow: `geaflow-ai/plugins/CODE_STYLES.md` + - Treat it as the baseline contract for changes under `geaflow-ai/plugins/casts/`. + - If you need to break a rule, document the reason in the PR/commit message. + +## Repository Layout (What Goes Where) + +- `core/`: deterministic cache + decision logic (no network calls required). +- `services/`: integration code (LLM / embedding / external I/O). +- `harness/`: offline simulation harness (data + executor + evaluator). +- `api/`: production-facing decision service (FastAPI). + - Endpoint: `POST /casts/decision` + - Safety: must degrade conservatively to `decision="stop"` on invalid input or upstream failures. +- `scripts/`: local developer scripts (e.g., smoke tests). +- `tests/`: pytest suite. + +## Local Dev (Python 3.11 + uv) + +We use a per-plugin venv in `.venv/` (gitignored) and a **no-activate** workflow. 
+ +One-time setup: + +```bash +cd geaflow-ai/plugins/casts +[ -d .venv ] || python3.11 -m venv .venv +uv sync --extra dev +``` + +Run tests: + +```bash +cd geaflow-ai/plugins/casts +uv run pytest -q +``` + +Run lint + type checks: + +```bash +cd geaflow-ai/plugins/casts +uv run ruff format --check . +uv run ruff check . +uv run mypy -p api -p core -p services -p harness +``` + +## Run The Service (FastAPI) + +```bash +cd geaflow-ai/plugins/casts +uv sync --extra dev +uv run uvicorn api.app:app --host 127.0.0.1 --port 5001 +``` + +One-click smoke: + +```bash +cd geaflow-ai/plugins/casts +./scripts/smoke.sh +``` + +## Safety Defaults + +- Never enable evaluating LLM-provided predicates in production: + - `LLM_ORACLE_ENABLE_PREDICATE_EVAL` must remain `False` by default. +- `scope` is a hard boundary: + - requests with empty scope must be rejected (or conservatively downgraded). + - CASTS service additionally requires `scope.run_id` for cache isolation; if missing, it downgrades to `stop`. diff --git a/geaflow-ai/plugins/casts/README.md b/geaflow-ai/plugins/casts/README.md index 9c45d0ce5..39b924546 100644 --- a/geaflow-ai/plugins/casts/README.md +++ b/geaflow-ai/plugins/casts/README.md @@ -12,15 +12,41 @@ CASTS stands for **Context-Aware Strategy Cache System**. - Cache traversal strategies (SKUs) to reduce repeated LLM calls. - Separate schema metadata from execution logic. -- Support both synthetic and real-world graph data. +- Support both synthetic and real-world graph data (in the Python harness). - Keep the core cache logic deterministic and testable. +## Two Modes: Production vs. Python Harness + +CASTS is designed so that **decisioning** and **execution** are separate concerns. + +### Production Mode (GeaFlow Java Integration) + +In production, the **GeaFlow Java data plane executes traversal** and CASTS is +used as a *decision service*: + +- Python returns a **single next-step Gremlin-style step string** (e.g. `out('friend')`, `inV()`, `stop`). 
+- Python does **not** execute multi-hop traversal against graph data. +- Java (`CastsOperator`) executes the step, expands the subgraph, and repeats until + `stop` or `maxDepth`. + +Key integration files in this repo: + +- Java executor/operator: `geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsOperator.java` +- Python decision service: `geaflow-ai/plugins/casts/api/app.py` (`POST /casts/decision`) + +### Python Harness Mode (Offline Simulation) + +The CASTS repo also includes a **Python-only harness** for experiments: + +- It implements its own in-memory data source + traversal executor for evaluation. +- It is intended for offline testing, hit-rate studies, and algorithm iteration. +- It can be misleading if read as “production execution”; treat it as a harness. + ## Module Layout -- `core`: cache, schema, configuration, and core models -- `data`: data sources and graph generation (synthetic + real) -- `services`: embedding + LLM services -- `simulation`: simulation engine and evaluation +- `core`: decision core (cache, models, GremlinStateMachine, validation) +- `services`: embedding + LLM integrations used by decisioning +- `harness`: Python-only data + simulation + executor (not production execution) - `tests`: unit and integration tests ## Repository Placement @@ -30,8 +56,9 @@ plugin, with the Python package located at the module root. ## Configuration (Required) -The following environment variables are required. Missing values raise a -`ValueError` at startup: +These env vars are required to run CASTS with real embedding + LLM services +(typically used by the Python harness). Missing values raise a `ValueError` +when those services are instantiated: - `EMBEDDING_ENDPOINT` - `EMBEDDING_APIKEY` @@ -45,12 +72,16 @@ automatic fallbacks for missing credentials. ## Real Data Loading -The default loader reads CSV files from `data/real_graph_data` (or -`real_graph_data`) and builds a directed graph. 
You can override this behavior -by providing a custom loader: +The default harness loader reads CSV files from: + +- `harness/data/real_graph_data` (repo default), or +- `real_graph_data` (alternate path), or +- a configured `GraphGeneratorConfig.real_data_dir` + +You can also override loading by providing a custom loader: ```python -from data.graph_generator import GraphGeneratorConfig, GraphGenerator +from harness.data.graph_generator import GraphGeneratorConfig, GraphGenerator def my_loader(config: GraphGeneratorConfig): # return nodes, edges @@ -65,14 +96,63 @@ graph = GraphGenerator(config=config) `InMemoryGraphSchema` caches type-level labels. If you mutate nodes or edges after creation, call `mark_dirty()` or `rebuild()` before querying schema data. +## Local Dev (Python 3.11 + uv) + +In this repo, each Python plugin keeps its own virtual environment at `.venv/` +(gitignored). + +One-time venv creation: + +```bash +cd geaflow-ai/plugins/casts +[ -d .venv ] || python3.11 -m venv .venv +``` + +Sync dependencies: + +```bash +cd geaflow-ai/plugins/casts +uv sync --extra dev +``` + +Notes: + +- You don't need to `source .venv/bin/activate` for normal workflows. + - `uv sync` installs into the project env (`.venv/`) + - `uv run ...` executes inside that env +- If you *do* activate a venv for interactive work, `uv sync --active` forces syncing + into the active environment. 
+ ## Running a Simulation -From the plugins directory (parent of this module): +From the CASTS plugin directory: + +```bash +cd geaflow-ai/plugins/casts +uv sync --extra harness +uv run python -m harness.simulation.runner +``` + +## Running the Service (FastAPI) + +Start the CASTS decision service (used by GeaFlow Java integration): + +```bash +cd geaflow-ai/plugins/casts +uv sync --extra dev +uv run uvicorn api.app:app --host 127.0.0.1 --port 5001 +``` + +Java side configuration (defaults shown): + +- `GEAFLOW_AI_CASTS_URL=http://localhost:5001` +- `GEAFLOW_AI_CASTS_TOKEN=` (optional bearer token) + +One-click smoke test (starts/stops the server as needed): ```bash -cd /Users/kuda/code/geaflow/geaflow-ai/plugins -uv sync -python -m simulation.runner +cd geaflow-ai/plugins/casts +./scripts/smoke.sh ``` ## Tests @@ -80,11 +160,25 @@ python -m simulation.runner Run tests locally: ```bash -uv sync -pytest +cd geaflow-ai/plugins/casts +uv sync --extra dev +uv run pytest -q +``` + +## Lint & Type Check + +Run lint (ruff) and type checks (mypy): + +```bash +cd geaflow-ai/plugins/casts +uv sync --extra dev +uv run ruff format --check . +uv run ruff check . +uv run mypy -p api -p core -p services -p harness ``` -There is no GitHub Actions workflow for this module by default. +CI: `.github/workflows/ci-py311.yml` runs the CASTS + LightMem Python tests on +Python 3.11 via `uv`. ## Documentation diff --git a/geaflow-ai/plugins/casts/api/__init__.py b/geaflow-ai/plugins/casts/api/__init__.py new file mode 100644 index 000000000..279e7b1a2 --- /dev/null +++ b/geaflow-ai/plugins/casts/api/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""CASTS HTTP API (FastAPI).""" diff --git a/geaflow-ai/plugins/casts/api/app.py b/geaflow-ai/plugins/casts/api/app.py new file mode 100644 index 000000000..a514dfc83 --- /dev/null +++ b/geaflow-ai/plugins/casts/api/app.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from typing import Any +import uuid + +from fastapi import FastAPI, Request +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse + +from api.envelope import Envelope +from api.service import decide + +app = FastAPI(title="GeaFlow AI CASTS", version="0.1.0") + + +@app.exception_handler(RequestValidationError) +async def _validation_exception_handler(_: Request, exc: RequestValidationError) -> JSONResponse: + trace_id = f"tr_{uuid.uuid4().hex}" + return JSONResponse( + status_code=422, + content={ + "ok": False, + "error": {"code": "INVALID_REQUEST", "message": str(exc)}, + "trace": {"trace_id": trace_id}, + }, + ) + + +@app.exception_handler(Exception) +async def _unhandled_exception_handler(_: Request, exc: Exception) -> JSONResponse: + trace_id = f"tr_{uuid.uuid4().hex}" + return JSONResponse( + status_code=500, + content={ + "ok": False, + "error": { + "code": "INTERNAL_ERROR", + "message": f"internal error ({type(exc).__name__})", + }, + "trace": {"trace_id": trace_id}, + }, + ) + + +@app.get("/health") +async def health() -> dict[str, Any]: + return {"status": "UP"} + + +@app.post("/casts/decision") +async def casts_decision(envelope: Envelope) -> JSONResponse: + trace = envelope.trace.with_defaults() + env = envelope.model_copy(update={"trace": trace}) + + payload = await decide(env) + return JSONResponse( + status_code=200, + content={ + "ok": True, + "api_version": env.api_version, + "trace": trace.model_dump(), + "payload": payload, + }, + ) diff --git a/geaflow-ai/plugins/casts/api/envelope.py b/geaflow-ai/plugins/casts/api/envelope.py new file mode 100644 index 000000000..8c622c173 --- /dev/null +++ b/geaflow-ai/plugins/casts/api/envelope.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import time +from typing import Any +import uuid + +from pydantic import BaseModel, ConfigDict, Field, model_validator + + +class Scope(BaseModel): + model_config = ConfigDict(extra="ignore") + + tenant_id: str | None = None + user_id: str | None = None + agent_id: str | None = None + run_id: str | None = None + actor_id: str | None = None + + @model_validator(mode="after") + def _scope_required(self) -> Scope: + has_any = any((self.tenant_id, self.user_id, self.agent_id, self.run_id, self.actor_id)) + if not has_any: + raise ValueError("scope is required (tenant_id/user_id/agent_id/run_id/actor_id)") + return self + + +class Trace(BaseModel): + model_config = ConfigDict(extra="ignore") + + trace_id: str | None = None + timestamp: float | None = None + caller: str | None = None + + def with_defaults(self) -> Trace: + return Trace( + trace_id=self.trace_id or f"tr_{uuid.uuid4().hex}", + timestamp=self.timestamp if self.timestamp is not None else time.time(), + caller=self.caller, + ) + + +class Envelope(BaseModel): + model_config = ConfigDict(extra="ignore") + + api_version: str = "v1" + scope: Scope + trace: Trace = Field(default_factory=Trace) + payload: dict[str, Any] = Field(default_factory=dict) diff --git a/geaflow-ai/plugins/casts/api/schema.py b/geaflow-ai/plugins/casts/api/schema.py new file mode 100644 index 000000000..c243c21be --- /dev/null +++ 
b/geaflow-ai/plugins/casts/api/schema.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import Any + +from core.interfaces import GraphSchema +from core.types import JsonDict + + +class RequestGraphSchema(GraphSchema): + """GraphSchema adapter built from request payload metadata.""" + + def __init__( + self, + *, + schema_fingerprint: str, + valid_outgoing_labels: list[str], + valid_incoming_labels: list[str], + node_types: set[str] | None = None, + edge_labels: set[str] | None = None, + node_schema: dict[str, Any] | None = None, + ) -> None: + self._schema_fingerprint = schema_fingerprint + self._valid_outgoing_labels = list(valid_outgoing_labels) + self._valid_incoming_labels = list(valid_incoming_labels) + self._node_types = set(node_types or []) + self._edge_labels = set(edge_labels or []) + self._node_schema = dict(node_schema or {}) + + if not self._edge_labels: + self._edge_labels = set(self._valid_outgoing_labels) | set(self._valid_incoming_labels) + + @property + def node_types(self) -> set[str]: + return set(self._node_types) + + @property + def edge_labels(self) -> set[str]: + return set(self._edge_labels) + + def get_node_schema(self, node_type: str) 
-> JsonDict: + if not node_type: + return {} + schema = self._node_schema.get(node_type) + if isinstance(schema, dict): + return schema + return {} + + def get_valid_outgoing_edge_labels(self, node_type: str) -> list[str]: + _ = node_type + return list(self._valid_outgoing_labels) + + def get_valid_incoming_edge_labels(self, node_type: str) -> list[str]: + _ = node_type + return list(self._valid_incoming_labels) + + def validate_edge_label(self, label: str) -> bool: + return label in self._edge_labels diff --git a/geaflow-ai/plugins/casts/api/service.py b/geaflow-ai/plugins/casts/api/service.py new file mode 100644 index 000000000..81babe884 --- /dev/null +++ b/geaflow-ai/plugins/casts/api/service.py @@ -0,0 +1,341 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from dataclasses import dataclass +import hashlib +import threading +from typing import Any + +import numpy as np +from pydantic import BaseModel, ConfigDict, Field + +from api.envelope import Envelope +from api.schema import RequestGraphSchema +from core.config import DefaultConfiguration +from core.gremlin_state import GremlinStateMachine +from core.interfaces import Configuration, EmbeddingServiceProtocol +from core.models import Context, StrategyKnowledgeUnit +from core.strategy_cache import StrategyCache +from services.embedding import EmbeddingService +from services.llm_oracle import LLMOracle + + +class _RequestScopedConfig(Configuration): + def __init__(self, base: Configuration, overrides: dict[str, Any]): + self._base = base + self._overrides = dict(overrides) + + def get(self, key: str, default: Any) -> Any: + if key in self._overrides: + return self._overrides[key] + return self._base.get(key, default) + + def get_int(self, key: str, default: int = 0) -> int: + return int(self.get(key, default)) + + def get_float(self, key: str, default: float = 0.0) -> float: + return float(self.get(key, default)) + + def get_bool(self, key: str, default: bool = False) -> bool: + return bool(self.get(key, default)) + + def get_str(self, key: str, default: str = "") -> str: + return str(self.get(key, default)) + + def get_llm_config(self) -> dict[str, str]: + return self._base.get_llm_config() + + def get_embedding_config(self) -> dict[str, str]: + return self._base.get_embedding_config() + + +class _LocalEmbeddingService: + def __init__(self, dimension: int = 64) -> None: + self._dimension = max(8, int(dimension)) + + async def embed_text(self, text: str) -> np.ndarray: + digest = hashlib.sha256(text.encode("utf-8")).digest() + raw = np.frombuffer(digest, dtype=np.uint8).astype(np.float32) + tiled = np.resize(raw, self._dimension) + norm = np.linalg.norm(tiled) + return tiled if norm == 0 else tiled / norm + + async def 
embed_properties(self, properties: dict[str, Any]) -> np.ndarray: + parts = [f"{k}={properties[k]}" for k in sorted(properties.keys())] + return await self.embed_text("|".join(parts)) + + +class _TraversalPayload(BaseModel): + model_config = ConfigDict(extra="ignore") + + structural_signature: str + step_index: int | None = None + + +class _NodePayload(BaseModel): + model_config = ConfigDict(extra="ignore") + + label: str + properties: dict[str, Any] = Field(default_factory=dict) + + +class _GraphSchemaPayload(BaseModel): + model_config = ConfigDict(extra="ignore") + + schema_fingerprint: str + valid_outgoing_labels: list[str] = Field(default_factory=list) + valid_incoming_labels: list[str] = Field(default_factory=list) + schema_summary: str | None = None + node_types: list[str] | None = None + edge_labels: list[str] | None = None + node_schema: dict[str, Any] | None = None + + +class CastsDecisionPayload(BaseModel): + model_config = ConfigDict(extra="ignore") + + goal: str + max_depth: int | None = None + traversal: _TraversalPayload + node: _NodePayload + graph_schema: _GraphSchemaPayload + constraints: dict[str, Any] | None = None + + +@dataclass +class _CastsRuntime: + cache_namespace: str + schema_fingerprint: str + cache: StrategyCache + llm_oracle: LLMOracle | None + + +class _CastsRuntimeManager: + def __init__(self) -> None: + self._lock = threading.Lock() + self._runtimes: dict[str, _CastsRuntime] = {} + + def get_runtime(self, *, run_id: str, schema_fingerprint: str) -> _CastsRuntime: + key = f"{run_id}|{schema_fingerprint}" + with self._lock: + runtime = self._runtimes.get(key) + if runtime is not None: + return runtime + + base_config = DefaultConfiguration() + config = _RequestScopedConfig( + base_config, + overrides={ + "CACHE_SCHEMA_FINGERPRINT": schema_fingerprint, + # Never eval arbitrary code from LLM predicates by default. 
+ "LLM_ORACLE_ENABLE_PREDICATE_EVAL": False, + }, + ) + + try: + embed: EmbeddingServiceProtocol = EmbeddingService(config) + except Exception: + embed = _LocalEmbeddingService() + + cache = StrategyCache(embed, config) + + try: + oracle = LLMOracle(embed, config) + except Exception: + oracle = None + + runtime = _CastsRuntime( + cache_namespace=run_id, + schema_fingerprint=schema_fingerprint, + cache=cache, + llm_oracle=oracle, + ) + self._runtimes[key] = runtime + return runtime + + +_RUNTIME_MANAGER = _CastsRuntimeManager() + + +def _schema_from_payload(p: _GraphSchemaPayload) -> RequestGraphSchema: + return RequestGraphSchema( + schema_fingerprint=p.schema_fingerprint, + valid_outgoing_labels=p.valid_outgoing_labels, + valid_incoming_labels=p.valid_incoming_labels, + node_types=set(p.node_types or []), + edge_labels=set(p.edge_labels or []), + node_schema=p.node_schema or {}, + ) + + +def _ensure_node_type(properties: dict[str, Any], fallback_label: str) -> dict[str, Any]: + if "type" not in properties and fallback_label: + return {**properties, "type": fallback_label} + return properties + + +def _dedupe_existing(cache: StrategyCache, new_sku: StrategyKnowledgeUnit) -> StrategyKnowledgeUnit: + for existing in cache.knowledge_base: + if ( + existing.structural_signature == new_sku.structural_signature + and existing.goal_template == new_sku.goal_template + and existing.decision_template == new_sku.decision_template + ): + return existing + cache.add_sku(new_sku) + return new_sku + + +async def decide(envelope: Envelope) -> dict[str, Any]: + """Core CASTS decision flow for one traversal step.""" + + try: + payload = CastsDecisionPayload.model_validate(envelope.payload) + except Exception as e: + return { + "decision": "stop", + "match_type": "STOP_INVALID", + "sku_id": None, + "provenance": { + "trace_id": envelope.trace.trace_id, + "schema_fingerprint": None, + "cache_namespace": envelope.scope.run_id, + "error": f"invalid_payload: {type(e).__name__}", + }, + } 
+ + run_id = (envelope.scope.run_id or "").strip() + if not run_id: + return { + "decision": "stop", + "match_type": "STOP_INVALID", + "sku_id": None, + "provenance": { + "trace_id": envelope.trace.trace_id, + "schema_fingerprint": payload.graph_schema.schema_fingerprint, + "cache_namespace": None, + "error": "scope.run_id is required for CASTS cache isolation", + }, + } + + runtime = _RUNTIME_MANAGER.get_runtime( + run_id=run_id, schema_fingerprint=payload.graph_schema.schema_fingerprint + ) + + schema = _schema_from_payload(payload.graph_schema) + node_properties = _ensure_node_type(payload.node.properties, payload.node.label) + context = Context( + structural_signature=payload.traversal.structural_signature, + properties=node_properties, + goal=payload.goal, + ) + + try: + decision, sku, match_type = await runtime.cache.find_strategy(context) + except Exception as e: + return { + "decision": "stop", + "match_type": "EMBEDDING_ERROR", + "sku_id": None, + "provenance": { + "trace_id": envelope.trace.trace_id, + "schema_fingerprint": runtime.schema_fingerprint, + "cache_namespace": runtime.cache_namespace, + "error": f"embedding_error: {type(e).__name__}", + }, + } + if match_type in ("Tier1", "Tier2") and decision: + return { + "decision": decision, + "match_type": match_type, + "sku_id": getattr(sku, "id", None) if sku else None, + "provenance": { + "trace_id": envelope.trace.trace_id, + "schema_fingerprint": runtime.schema_fingerprint, + "cache_namespace": runtime.cache_namespace, + }, + } + + if runtime.llm_oracle is None: + return { + "decision": "stop", + "match_type": "LLM_DISABLED", + "sku_id": None, + "provenance": { + "trace_id": envelope.trace.trace_id, + "schema_fingerprint": runtime.schema_fingerprint, + "cache_namespace": runtime.cache_namespace, + }, + } + + state, next_step_options = GremlinStateMachine.get_state_and_options( + context.structural_signature, schema, str(context.properties.get("type") or "") + ) + if state == "END" or not 
next_step_options: + return { + "decision": "stop", + "match_type": "STOP_END", + "sku_id": None, + "provenance": { + "trace_id": envelope.trace.trace_id, + "schema_fingerprint": runtime.schema_fingerprint, + "cache_namespace": runtime.cache_namespace, + }, + } + + try: + new_sku = await runtime.llm_oracle.generate_sku(context, schema) + except Exception as e: + return { + "decision": "stop", + "match_type": "LLM_ERROR", + "sku_id": None, + "provenance": { + "trace_id": envelope.trace.trace_id, + "schema_fingerprint": runtime.schema_fingerprint, + "cache_namespace": runtime.cache_namespace, + "error": f"llm_oracle_error: {type(e).__name__}", + }, + } + new_sku = _dedupe_existing(runtime.cache, new_sku) + + # Safety: Project decision into allowed options. + if new_sku.decision_template not in next_step_options: + return { + "decision": "stop", + "match_type": "STOP_INVALID", + "sku_id": getattr(new_sku, "id", None), + "provenance": { + "trace_id": envelope.trace.trace_id, + "schema_fingerprint": runtime.schema_fingerprint, + "cache_namespace": runtime.cache_namespace, + "error": "decision not in next_step_options", + }, + } + + return { + "decision": new_sku.decision_template, + "match_type": "LLM", + "sku_id": getattr(new_sku, "id", None), + "provenance": { + "trace_id": envelope.trace.trace_id, + "schema_fingerprint": runtime.schema_fingerprint, + "cache_namespace": runtime.cache_namespace, + }, + } diff --git a/geaflow-ai/plugins/casts/core/config.py b/geaflow-ai/plugins/casts/core/config.py index 1e86d6074..8cb3525bc 100644 --- a/geaflow-ai/plugins/casts/core/config.py +++ b/geaflow-ai/plugins/casts/core/config.py @@ -67,7 +67,7 @@ class DefaultConfiguration(Configuration): True # If True, use real data from CSVs; otherwise, generate synthetic data. ) SIMULATION_REAL_DATA_DIR = ( - "data/real_graph_data" # Directory containing the real graph data CSV files. + "harness/data/real_graph_data" # Directory containing the real graph data CSV files. 
) SIMULATION_REAL_SUBGRAPH_SIZE = 200 # Max number of nodes to sample for the real data subgraph. SIMULATION_ENABLE_VERIFIER = True # If True, enables the LLM-based path evaluator. @@ -139,6 +139,14 @@ class DefaultConfiguration(Configuration): # Fingerprint for the current graph schema. Changing this will invalidate all existing SKUs. CACHE_SCHEMA_FINGERPRINT = "schema_v1" + # ============================================ + # LLM ORACLE SAFETY + # ============================================ + # If True, allow evaluating LLM-provided predicate code via `eval`. + # This is unsafe for production and should remain False unless running in + # a trusted/offline environment. + LLM_ORACLE_ENABLE_PREDICATE_EVAL = False + # SIGNATURE CONFIGURATION # Signature abstraction level, used as a MATCHING STRATEGY at runtime. # SKUs are always stored in their canonical, most detailed (Level 2) format. diff --git a/geaflow-ai/plugins/casts/core/interfaces.py b/geaflow-ai/plugins/casts/core/interfaces.py index 9c8409a20..10ffc8e2e 100644 --- a/geaflow-ai/plugins/casts/core/interfaces.py +++ b/geaflow-ai/plugins/casts/core/interfaces.py @@ -214,3 +214,8 @@ def get_str(self, key: str, default: str = "") -> str: def get_llm_config(self) -> dict[str, str]: """Get LLM service configuration.""" pass + + @abstractmethod + def get_embedding_config(self) -> dict[str, str]: + """Get embedding service configuration.""" + pass diff --git a/geaflow-ai/plugins/casts/core/models.py b/geaflow-ai/plugins/casts/core/models.py index 6e3c0d117..d7e4503b0 100644 --- a/geaflow-ai/plugins/casts/core/models.py +++ b/geaflow-ai/plugins/casts/core/models.py @@ -36,12 +36,13 @@ def filter_decision_properties(properties: JsonDict) -> JsonDict: @dataclass class Context: """Runtime context c = (structural_signature, properties, goal) - + Represents the current state of a graph traversal: - structural_signature: Current traversal path as a string (e.g., "V().out().in()") - properties: Current node properties (with 
identity fields filtered out) - goal: Natural language description of the traversal objective """ + structural_signature: str properties: JsonDict goal: str @@ -55,13 +56,13 @@ def safe_properties(self) -> JsonDict: @dataclass class StrategyKnowledgeUnit: """Strategy Knowledge Unit (SKU) - Core building block of the strategy cache. - + Mathematical definition: - SKU = (context_template, decision_template, schema_fingerprint, + SKU = (context_template, decision_template, schema_fingerprint, property_vector, confidence_score, logic_complexity) - + where context_template = (structural_signature, predicate, goal_template) - + Attributes: id: Unique identifier for this SKU structural_signature: s_sku - structural pattern that must match exactly @@ -73,6 +74,7 @@ class StrategyKnowledgeUnit: confidence_score: eta - dynamic confidence score (AIMD updated) logic_complexity: sigma_logic - intrinsic logic complexity measure """ + id: str structural_signature: str predicate: Callable[[JsonDict], bool] diff --git a/geaflow-ai/plugins/casts/core/schema.py b/geaflow-ai/plugins/casts/core/schema.py index 54ec7b6d5..a52b402ee 100644 --- a/geaflow-ai/plugins/casts/core/schema.py +++ b/geaflow-ai/plugins/casts/core/schema.py @@ -38,9 +38,7 @@ class SchemaState(str, Enum): class InMemoryGraphSchema(GraphSchema): """In-memory implementation of GraphSchema for CASTS data sources.""" - def __init__( - self, nodes: GraphNodes, edges: GraphEdges - ): + def __init__(self, nodes: GraphNodes, edges: GraphEdges): """Initialize schema from graph data. 
Args: diff --git a/geaflow-ai/plugins/casts/core/strategy_cache.py b/geaflow-ai/plugins/casts/core/strategy_cache.py index 179733252..33aa30128 100644 --- a/geaflow-ai/plugins/casts/core/strategy_cache.py +++ b/geaflow-ai/plugins/casts/core/strategy_cache.py @@ -153,7 +153,7 @@ def _to_abstract_signature(self, signature: str) -> str: return signature abstract_parts = [] - steps = signature.split('.') + steps = signature.split(".") for i, step in enumerate(steps): if i == 0: abstract_parts.append(step) diff --git a/geaflow-ai/plugins/casts/data/__init__.py b/geaflow-ai/plugins/casts/harness/__init__.py similarity index 100% rename from geaflow-ai/plugins/casts/data/__init__.py rename to geaflow-ai/plugins/casts/harness/__init__.py diff --git a/geaflow-ai/plugins/casts/simulation/__init__.py b/geaflow-ai/plugins/casts/harness/data/__init__.py similarity index 100% rename from geaflow-ai/plugins/casts/simulation/__init__.py rename to geaflow-ai/plugins/casts/harness/data/__init__.py diff --git a/geaflow-ai/plugins/casts/data/graph_generator.py b/geaflow-ai/plugins/casts/harness/data/graph_generator.py similarity index 94% rename from geaflow-ai/plugins/casts/data/graph_generator.py rename to geaflow-ai/plugins/casts/harness/data/graph_generator.py index 60ffb5d45..704ac1564 100644 --- a/geaflow-ai/plugins/casts/data/graph_generator.py +++ b/geaflow-ai/plugins/casts/harness/data/graph_generator.py @@ -20,7 +20,8 @@ This module supports two data sources: 1. Synthetic graph data with Zipf-like distribution (default). -2. Real transaction/relationship data loaded from CSV files under ``real_graph_data/`` +2. Real transaction/relationship data loaded from CSV files under ``real_graph_data/`` (or + the repo-default ``harness/data/real_graph_data/``) (or a custom loader via ``GraphGeneratorConfig.real_data_loader``). Use :class:`GraphGenerator` as the unified in-memory representation. 
The simulation @@ -34,7 +35,7 @@ from core.constants import EDGE_LABEL_KEY, EDGE_TARGET_KEY, NODE_ID_KEY, NODE_TYPE_KEY from core.types import GraphEdges, GraphNodes, JsonDict -from data.real_graph_loader import RealGraphLoader, default_real_graph_loader +from harness.data.real_graph_loader import RealGraphLoader, default_real_graph_loader @dataclass @@ -106,7 +107,7 @@ def _generate_zipf_data(self, size: int) -> None: ] # Weights approximating 1/k distribution type_weights = [100, 50, 25, 12, 6] - + business_categories = ["retail", "wholesale", "finance", "manufacturing"] regions = ["NA", "EU", "APAC", "LATAM"] risk_levels = ["low", "medium", "high"] @@ -116,7 +117,7 @@ def _generate_zipf_data(self, size: int) -> None: node_type = random.choices(business_types, weights=type_weights, k=1)[0] status = "active" if random.random() < 0.8 else "inactive" age = random.randint(18, 60) - + node: JsonDict = { NODE_ID_KEY: str(i), NODE_TYPE_KEY: node_type, @@ -142,11 +143,8 @@ def _generate_zipf_data(self, size: int) -> None: if self.nodes[str(i)]["type"] == "Retail SME" and random.random() < 0.7: label = "related" elif ( - self.nodes[str(i)]["type"] == "Logistics Partner" - and random.random() < 0.7 + self.nodes[str(i)]["type"] == "Logistics Partner" and random.random() < 0.7 ): label = "friend" - self.edges[str(i)].append( - {EDGE_TARGET_KEY: str(target), EDGE_LABEL_KEY: label} - ) + self.edges[str(i)].append({EDGE_TARGET_KEY: str(target), EDGE_LABEL_KEY: label}) diff --git a/geaflow-ai/plugins/casts/data/real_graph_loader.py b/geaflow-ai/plugins/casts/harness/data/real_graph_loader.py similarity index 97% rename from geaflow-ai/plugins/casts/data/real_graph_loader.py rename to geaflow-ai/plugins/casts/harness/data/real_graph_loader.py index 4e0992cce..6408931c5 100644 --- a/geaflow-ai/plugins/casts/data/real_graph_loader.py +++ b/geaflow-ai/plugins/casts/harness/data/real_graph_loader.py @@ -28,7 +28,7 @@ from core.types import GraphEdges, GraphNodes, JsonDict if 
TYPE_CHECKING: - from data.graph_generator import GraphGeneratorConfig + from harness.data.graph_generator import GraphGeneratorConfig RealGraphLoader = Callable[ ["GraphGeneratorConfig"], @@ -214,7 +214,7 @@ def add_undirected(u: str, v: str) -> None: def _resolve_data_dir(real_data_dir: str | None) -> Path: """Resolve the directory that contains real graph CSV files.""" - project_root = Path(__file__).resolve().parents[1] + project_root = Path(__file__).resolve().parents[2] if real_data_dir: configured = Path(real_data_dir) @@ -225,7 +225,7 @@ def _resolve_data_dir(real_data_dir: str | None) -> Path: return configured default_candidates = [ - project_root / "data" / "real_graph_data", + project_root / "harness" / "data" / "real_graph_data", project_root / "real_graph_data", ] for candidate in default_candidates: diff --git a/geaflow-ai/plugins/casts/data/sources.py b/geaflow-ai/plugins/casts/harness/data/sources.py similarity index 92% rename from geaflow-ai/plugins/casts/data/sources.py rename to geaflow-ai/plugins/casts/harness/data/sources.py index 5e78b37f5..ee5c4c2b6 100644 --- a/geaflow-ai/plugins/casts/data/sources.py +++ b/geaflow-ai/plugins/casts/harness/data/sources.py @@ -118,9 +118,7 @@ def __init__(self, node_types: set[str], edge_labels: set[str]): loan = "Loan" if "Loan" in node_types else "loan node" invest = "invest" if "invest" in edge_labels else "invest relation" - guarantee = ( - "guarantee" if "guarantee" in edge_labels else "guarantee relation" - ) + guarantee = "guarantee" if "guarantee" in edge_labels else "guarantee relation" transfer = "transfer" if "transfer" in edge_labels else "transfer relation" withdraw = "withdraw" if "withdraw" in edge_labels else "withdraw relation" repay = "repay" if "repay" in edge_labels else "repay relation" @@ -194,9 +192,7 @@ def select_goal(self, node_type: str | None = None) -> tuple[str, str]: candidates = list(c_tuple) weights = list(w_tuple) - selected_goal, selected_rubric = random.choices( - 
candidates, weights=weights, k=1 - )[0] + selected_goal, selected_rubric = random.choices(candidates, weights=weights, k=1)[0] return selected_goal, selected_rubric @@ -205,7 +201,7 @@ class SyntheticDataSource(DataSource): def __init__(self, size: int = 30): """Initialize synthetic data source. - + Args: size: Number of nodes to generate """ @@ -243,8 +239,8 @@ def get_neighbors(self, node_id: str, edge_label: str | None = None) -> list[str neighbors = [] for edge in self._edges[node_id]: - if edge_label is None or edge['label'] == edge_label: - neighbors.append(edge['target']) + if edge_label is None or edge["label"] == edge_label: + neighbors.append(edge["target"]) return neighbors def get_schema(self) -> GraphSchema: @@ -323,37 +319,37 @@ def get_starting_nodes( def _generate_zipf_data(self, size: int): """Generate synthetic data following Zipf distribution.""" business_types = [ - 'Retail SME', - 'Logistics Partner', - 'Enterprise Vendor', - 'Regional Distributor', - 'FinTech Startup', + "Retail SME", + "Logistics Partner", + "Enterprise Vendor", + "Regional Distributor", + "FinTech Startup", ] type_weights = [100, 50, 25, 12, 6] - business_categories = ['retail', 'wholesale', 'finance', 'manufacturing'] - regions = ['NA', 'EU', 'APAC', 'LATAM'] - risk_levels = ['low', 'medium', 'high'] + business_categories = ["retail", "wholesale", "finance", "manufacturing"] + regions = ["NA", "EU", "APAC", "LATAM"] + risk_levels = ["low", "medium", "high"] # Generate nodes for i in range(size): node_type = random.choices(business_types, weights=type_weights, k=1)[0] - status = 'active' if random.random() < 0.8 else 'inactive' + status = "active" if random.random() < 0.8 else "inactive" age = random.randint(18, 60) node = { - 'id': str(i), - 'type': node_type, - 'category': random.choice(business_categories), - 'region': random.choice(regions), - 'risk': random.choice(risk_levels), - 'status': status, - 'age': age, + "id": str(i), + "type": node_type, + "category": 
random.choice(business_categories), + "region": random.choice(regions), + "risk": random.choice(risk_levels), + "status": status, + "age": age, } self._nodes[str(i)] = node # Generate edges with more structured, denser relationship patterns - edge_labels = ['friend', 'supplier', 'partner', 'investor', 'customer'] + edge_labels = ["friend", "supplier", "partner", "investor", "customer"] # Baseline randomness: ensure each node has some edges. for i in range(size): @@ -363,34 +359,34 @@ def _generate_zipf_data(self, size: int): if target_id == str(i): continue label = random.choice(edge_labels) - edge = {'target': target_id, 'label': label} + edge = {"target": target_id, "label": label} self._edges.setdefault(str(i), []).append(edge) # Structural bias: different business types favor certain relations # to help the LLM learn stable patterns. for i in range(size): src_id = str(i) - node_type = self._nodes[src_id]['type'] + node_type = self._nodes[src_id]["type"] # Retail SME: more customer / supplier edges - if node_type == 'Retail SME': - extra_labels = ['customer', 'supplier'] + if node_type == "Retail SME": + extra_labels = ["customer", "supplier"] extra_edges = 2 # Logistics Partner: more partner / supplier edges - elif node_type == 'Logistics Partner': - extra_labels = ['partner', 'supplier'] + elif node_type == "Logistics Partner": + extra_labels = ["partner", "supplier"] extra_edges = 2 # Enterprise Vendor: more supplier / investor edges - elif node_type == 'Enterprise Vendor': - extra_labels = ['supplier', 'investor'] + elif node_type == "Enterprise Vendor": + extra_labels = ["supplier", "investor"] extra_edges = 2 # Regional Distributor: more partner / customer edges - elif node_type == 'Regional Distributor': - extra_labels = ['partner', 'customer'] + elif node_type == "Regional Distributor": + extra_labels = ["partner", "customer"] extra_edges = 2 # FinTech Startup: more investor / partner edges else: # 'FinTech Startup' - extra_labels = ['investor', 
'partner'] + extra_labels = ["investor", "partner"] extra_edges = 3 # Slightly higher to test deeper paths. for _ in range(extra_edges): @@ -398,7 +394,7 @@ def _generate_zipf_data(self, size: int): if target_id == src_id: continue label = random.choice(extra_labels) - edge = {'target': target_id, 'label': label} + edge = {"target": target_id, "label": label} self._edges.setdefault(src_id, []).append(edge) # Optional: increase global "friend" connectivity to reduce isolated components. @@ -407,7 +403,7 @@ def _generate_zipf_data(self, size: int): if random.random() < 0.3: # 30% of nodes add an extra friend edge. target_id = str(random.randint(0, size - 1)) if target_id != src_id: - edge = {'target': target_id, 'label': 'friend'} + edge = {"target": target_id, "label": "friend"} self._edges.setdefault(src_id, []).append(edge) @@ -465,8 +461,8 @@ def get_neighbors(self, node_id: str, edge_label: str | None = None) -> list[str neighbors = [] for edge in self._edges[node_id]: - if edge_label is None or edge['label'] == edge_label: - neighbors.append(edge['target']) + if edge_label is None or edge["label"] == edge_label: + neighbors.append(edge["target"]) return neighbors def reload(self): @@ -778,12 +774,10 @@ def _add_edge_if_not_exists(self, src_id, tgt_id, label): # Check if a similar edge already exists for edge in self._edges[src_id]: - if edge['target'] == tgt_id and edge['label'] == label: + if edge["target"] == tgt_id and edge["label"] == label: return # Edge already exists - self._edges[src_id].append({'target': tgt_id, 'label': label}) - - + self._edges[src_id].append({"target": tgt_id, "label": label}) def _load_nodes_from_csv(self, filepath: Path, entity_type: str): """Load nodes from a CSV file using actual column names as attributes.""" @@ -791,32 +785,32 @@ def _load_nodes_from_csv(self, filepath: Path, entity_type: str): return try: - with open(filepath, encoding='utf-8') as f: + with open(filepath, encoding="utf-8") as f: # Use DictReader to get actual 
column names - reader = csv.DictReader(f, delimiter='|') + reader = csv.DictReader(f, delimiter="|") if not reader.fieldnames: return # First column is the ID field id_field = reader.fieldnames[0] - + for row in reader: raw_id = row.get(id_field) if not raw_id: # Skip empty IDs continue - + node_id = f"{entity_type}_{raw_id}" node = { - 'id': node_id, - 'type': entity_type, - 'raw_id': raw_id, + "id": node_id, + "type": entity_type, + "raw_id": raw_id, } - + # Add all fields using their real column names for field_name, field_value in row.items(): if field_name != id_field and field_value: node[field_name] = field_value - + self._nodes[node_id] = node except Exception as e: print(f"Warning: Error loading {filepath}: {e}") @@ -827,8 +821,8 @@ def _load_edges_from_csv(self, filepath: Path, from_type: str, to_type: str, lab return try: - with open(filepath, encoding='utf-8') as f: - reader = csv.reader(f, delimiter='|') + with open(filepath, encoding="utf-8") as f: + reader = csv.reader(f, delimiter="|") for row in reader: if len(row) >= 2: src_id = f"{from_type}_{row[0]}" @@ -836,7 +830,7 @@ def _load_edges_from_csv(self, filepath: Path, from_type: str, to_type: str, lab # Only add edge if both nodes exist if src_id in self._nodes and tgt_id in self._nodes: - edge = {'target': tgt_id, 'label': label} + edge = {"target": tgt_id, "label": label} if src_id not in self._edges: self._edges[src_id] = [] self._edges[src_id].append(edge) @@ -860,7 +854,7 @@ def _sample_subgraph(self): G.add_node(node_id, **node) for src_id, edge_List in self._edges.items(): for edge in edge_List: - G.add_edge(src_id, edge['target'], label=edge['label']) + G.add_edge(src_id, edge["target"], label=edge["label"]) # Find largest connected component if not G.nodes(): @@ -914,9 +908,7 @@ def _sample_subgraph(self): random.shuffle(deduped) deduped.sort( key=lambda nid: ( - 0 - if G.nodes[nid].get("type", "Unknown") not in seen_types - else 1 + 0 if G.nodes[nid].get("type", "Unknown") not in 
seen_types else 1 ) ) @@ -935,9 +927,7 @@ def _sample_subgraph(self): # Filter nodes and edges to sampled subset self._nodes = { - node_id: node - for node_id, node in self._nodes.items() - if node_id in sampled_nodes + node_id: node for node_id, node in self._nodes.items() if node_id in sampled_nodes } self._edges = { src_id: [edge for edge in edges if edge["target"] in sampled_nodes] diff --git a/geaflow-ai/plugins/casts/harness/simulation/__init__.py b/geaflow-ai/plugins/casts/harness/simulation/__init__.py new file mode 100644 index 000000000..245692337 --- /dev/null +++ b/geaflow-ai/plugins/casts/harness/simulation/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ diff --git a/geaflow-ai/plugins/casts/simulation/engine.py b/geaflow-ai/plugins/casts/harness/simulation/engine.py similarity index 97% rename from geaflow-ai/plugins/casts/simulation/engine.py rename to geaflow-ai/plugins/casts/harness/simulation/engine.py index c15f0d172..3cef779b1 100644 --- a/geaflow-ai/plugins/casts/simulation/engine.py +++ b/geaflow-ai/plugins/casts/harness/simulation/engine.py @@ -25,9 +25,9 @@ from core.models import Context, StrategyKnowledgeUnit from core.strategy_cache import StrategyCache from core.types import TraversalResult +from harness.simulation.executor import TraversalExecutor +from harness.simulation.metrics import MetricsCollector, PathStep from services.llm_oracle import LLMOracle -from simulation.executor import TraversalExecutor -from simulation.metrics import MetricsCollector, PathStep CyclePenaltyMode = Literal["NONE", "PUNISH", "STOP"] @@ -94,9 +94,7 @@ async def run_epoch( sample_nodes = [] # 4. Initialize traversers for the starting nodes - current_layer: list[ - tuple[str, str, str, int, int | None, str | None, str | None] - ] = [] + current_layer: list[tuple[str, str, str, int, int | None, str | None, str | None]] = [] for node_id in sample_nodes: request_id = metrics_collector.initialize_path( epoch, node_id, self.graph.nodes[node_id], goal_text, rubric @@ -167,9 +165,7 @@ def execute_prechecker( raw_cycle_penalty_mode = self.llm_oracle.config.get_str("CYCLE_PENALTY").upper() if raw_cycle_penalty_mode not in ("NONE", "PUNISH", "STOP"): raw_cycle_penalty_mode = "STOP" - cycle_penalty_mode: CyclePenaltyMode = cast( - CyclePenaltyMode, raw_cycle_penalty_mode - ) + cycle_penalty_mode: CyclePenaltyMode = cast(CyclePenaltyMode, raw_cycle_penalty_mode) # Mode: NONE - skip all validation if cycle_penalty_mode == "NONE": @@ -207,9 +203,7 @@ def execute_prechecker( # === VALIDATION 2: Confidence Threshold === # Check if SKU confidence has fallen too low - min_confidence = self.llm_oracle.config.get_float( - 
"MIN_EXECUTION_CONFIDENCE" - ) + min_confidence = self.llm_oracle.config.get_float("MIN_EXECUTION_CONFIDENCE") if sku.confidence_score < min_confidence: if self.verbose: print( @@ -311,9 +305,7 @@ async def execute_tick( if self.verbose: print(f"\n[Tick {tick}] Processing {len(current_layer)} active traversers") - next_layer: list[ - tuple[str, str, str, int, int | None, str | None, str | None] - ] = [] + next_layer: list[tuple[str, str, str, int, int | None, str | None, str | None]] = [] for idx, traversal_state in enumerate(current_layer): ( diff --git a/geaflow-ai/plugins/casts/simulation/evaluator.py b/geaflow-ai/plugins/casts/harness/simulation/evaluator.py similarity index 98% rename from geaflow-ai/plugins/casts/simulation/evaluator.py rename to geaflow-ai/plugins/casts/harness/simulation/evaluator.py index ec240c0f8..b7d53d413 100644 --- a/geaflow-ai/plugins/casts/simulation/evaluator.py +++ b/geaflow-ai/plugins/casts/harness/simulation/evaluator.py @@ -26,10 +26,11 @@ """ from dataclasses import dataclass, field +from typing import Sequence from core.types import JsonDict +from harness.simulation.metrics import PathInfo, PathStep from services.path_judge import PathJudge -from simulation.metrics import PathInfo, PathStep from utils.helpers import parse_jsons QUERY_MAX_SCORE = 35.0 @@ -312,9 +313,7 @@ def _score_strategy_reusability( return min(STRATEGY_MAX_SCORE, score), detail - def _score_cache_efficiency( - self, match_types: list[str | None] - ) -> tuple[float, JsonDict]: + def _score_cache_efficiency(self, match_types: Sequence[str | None]) -> tuple[float, JsonDict]: detail: JsonDict = {} total = len(match_types) if total == 0: @@ -390,9 +389,7 @@ def _score_decision_consistency( return min(CONSISTENCY_MAX_SCORE, score), detail - def _score_information_utility( - self, props: list[JsonDict] - ) -> tuple[float, JsonDict]: + def _score_information_utility(self, props: list[JsonDict]) -> tuple[float, JsonDict]: detail: JsonDict = {} if not props: return 
0.0, {"note": "no_properties"} diff --git a/geaflow-ai/plugins/casts/simulation/executor.py b/geaflow-ai/plugins/casts/harness/simulation/executor.py similarity index 98% rename from geaflow-ai/plugins/casts/simulation/executor.py rename to geaflow-ai/plugins/casts/harness/simulation/executor.py index 54bebf43b..5bd6be366 100644 --- a/geaflow-ai/plugins/casts/simulation/executor.py +++ b/geaflow-ai/plugins/casts/harness/simulation/executor.py @@ -38,8 +38,11 @@ def _ensure_path_history(self, request_id: int, current_node_id: str) -> set[str return self._path_history[request_id] async def execute_decision( - self, current_node_id: str, decision: str, current_signature: str, - request_id: int | None = None + self, + current_node_id: str, + decision: str, + current_signature: str, + request_id: int | None = None, ) -> list[tuple[str, str, tuple[str, str] | None]]: """ Execute a traversal decision and return next nodes with updated signatures. diff --git a/geaflow-ai/plugins/casts/simulation/metrics.py b/geaflow-ai/plugins/casts/harness/simulation/metrics.py similarity index 91% rename from geaflow-ai/plugins/casts/simulation/metrics.py rename to geaflow-ai/plugins/casts/harness/simulation/metrics.py index 2425c90f6..51c386220 100644 --- a/geaflow-ai/plugins/casts/simulation/metrics.py +++ b/geaflow-ai/plugins/casts/harness/simulation/metrics.py @@ -24,6 +24,7 @@ MatchType = Literal["Tier1", "Tier2", ""] + class PathStep(TypedDict): """Recorded traversal step for a request.""" @@ -39,6 +40,7 @@ class PathStep(TypedDict): sku_id: str | None decision: str | None + class PathInfo(TypedDict): """Traversal path metadata and step history.""" @@ -49,6 +51,7 @@ class PathInfo(TypedDict): rubric: str steps: list[PathStep] + class MetricsSummary(TypedDict): """Summary of aggregate simulation metrics.""" @@ -112,9 +115,9 @@ def __init__(self): def record_step(self, match_type: MatchType | None = None) -> None: """Record a traversal step execution.""" self.metrics.total_steps += 1 
- if match_type == 'Tier1': + if match_type == "Tier1": self.metrics.tier1_hits += 1 - elif match_type == 'Tier2': + elif match_type == "Tier2": self.metrics.tier2_hits += 1 else: self.metrics.misses += 1 @@ -123,7 +126,7 @@ def record_step(self, match_type: MatchType | None = None) -> None: def record_execution_failure(self) -> None: """Record a failed strategy execution.""" self.metrics.execution_failures += 1 - + def record_sku_eviction(self, count: int = 1) -> None: """Record SKU evictions from cache cleanup.""" self.metrics.sku_evictions += count @@ -146,10 +149,10 @@ def initialize_path( "start_node_props": start_node_props, "goal": goal, "rubric": rubric, - "steps": [] + "steps": [], } return request_id - + def record_path_step( self, request_id: int, @@ -168,21 +171,23 @@ def record_path_step( """Record a step in a traversal path.""" if request_id not in self.paths: return - - self.paths[request_id]["steps"].append({ - "tick": tick, - "node": node_id, - "parent_node": parent_node, - # For visualization only: explicit edge to previous step - "parent_step_index": parent_step_index, - "edge_label": edge_label, - "s": structural_signature, - "g": goal, - "p": dict(properties), - "match_type": match_type, - "sku_id": sku_id, - "decision": decision - }) + + self.paths[request_id]["steps"].append( + { + "tick": tick, + "node": node_id, + "parent_node": parent_node, + # For visualization only: explicit edge to previous step + "parent_step_index": parent_step_index, + "edge_label": edge_label, + "s": structural_signature, + "g": goal, + "p": dict(properties), + "match_type": match_type, + "sku_id": sku_id, + "decision": decision, + } + ) def rollback_steps(self, request_id: int, count: int = 1) -> bool: """ @@ -226,7 +231,7 @@ def get_summary(self) -> MetricsSummary: "sku_evictions": self.metrics.sku_evictions, "hit_rate": self.metrics.hit_rate, } - + def print_summary(self) -> None: """Print a formatted summary of simulation metrics.""" print("\n=== Simulation 
Results Analysis ===") diff --git a/geaflow-ai/plugins/casts/simulation/runner.py b/geaflow-ai/plugins/casts/harness/simulation/runner.py similarity index 91% rename from geaflow-ai/plugins/casts/simulation/runner.py rename to geaflow-ai/plugins/casts/harness/simulation/runner.py index 541cd1bae..1aa4cfb6e 100644 --- a/geaflow-ai/plugins/casts/simulation/runner.py +++ b/geaflow-ai/plugins/casts/harness/simulation/runner.py @@ -22,14 +22,14 @@ from core.config import DefaultConfiguration from core.strategy_cache import StrategyCache from core.types import JsonDict -from data.sources import DataSourceFactory +from harness.data.sources import DataSourceFactory +from harness.simulation.engine import SimulationEngine +from harness.simulation.evaluator import BatchEvaluator, PathEvaluationScore, PathEvaluator +from harness.simulation.metrics import MetricsCollector +from harness.simulation.visualizer import SimulationVisualizer from services.embedding import EmbeddingService from services.llm_oracle import LLMOracle from services.path_judge import PathJudge -from simulation.engine import SimulationEngine -from simulation.evaluator import BatchEvaluator, PathEvaluationScore, PathEvaluator -from simulation.metrics import MetricsCollector -from simulation.visualizer import SimulationVisualizer async def run_simulation(): @@ -93,14 +93,12 @@ def evaluate_completed_request(request_id: int, metrics_collector: MetricsCollec # Run simulation metrics_collector = await engine.run_simulation( num_epochs=config.get_int("SIMULATION_NUM_EPOCHS"), - on_request_completed=evaluate_completed_request + on_request_completed=evaluate_completed_request, ) # Get sorted SKUs for reporting sorted_skus = sorted( - strategy_cache.knowledge_base, - key=lambda x: x.confidence_score, - reverse=True + strategy_cache.knowledge_base, key=lambda x: x.confidence_score, reverse=True ) # Print results diff --git a/geaflow-ai/plugins/casts/simulation/visualizer.py 
b/geaflow-ai/plugins/casts/harness/simulation/visualizer.py similarity index 97% rename from geaflow-ai/plugins/casts/simulation/visualizer.py rename to geaflow-ai/plugins/casts/harness/simulation/visualizer.py index 42d26d099..577f99432 100644 --- a/geaflow-ai/plugins/casts/simulation/visualizer.py +++ b/geaflow-ai/plugins/casts/harness/simulation/visualizer.py @@ -24,7 +24,7 @@ from core.interfaces import DataSource from core.models import Context, StrategyKnowledgeUnit from core.strategy_cache import StrategyCache -from simulation.metrics import PathInfo, PathStep, SimulationMetrics +from harness.simulation.metrics import PathInfo, PathStep, SimulationMetrics from utils.helpers import ( calculate_dynamic_similarity_threshold, calculate_tier2_threshold, @@ -119,9 +119,7 @@ def print_knowledge_base_state(sorted_skus: list[StrategyKnowledgeUnit]): print(f" - structural_signature: {sku.structural_signature}") vector_head = sku.property_vector[:3] rounded_head = [round(x, 3) for x in vector_head] - vector_summary = ( - f"Vector(dim={len(sku.property_vector)}, head={rounded_head}...)" - ) + vector_summary = f"Vector(dim={len(sku.property_vector)}, head={rounded_head}...)" print(f" - property_vector: {vector_summary}") print(f" - goal_template: {sku.goal_template}") print(f" - decision_template: {sku.decision_template}") @@ -225,9 +223,7 @@ async def print_all_results( # Generate matplotlib visualizations if graph is provided if graph is not None: - SimulationVisualizer.plot_all_traversal_paths( - paths=paths, graph=graph, show=show_plots - ) + SimulationVisualizer.plot_all_traversal_paths(paths=paths, graph=graph, show=show_plots) @staticmethod def plot_traversal_path( @@ -392,9 +388,7 @@ def plot_traversal_path( return fig @staticmethod - def plot_all_traversal_paths( - paths: dict[int, PathInfo], graph: DataSource, show: bool = True - ): + def plot_all_traversal_paths(paths: dict[int, PathInfo], graph: DataSource, show: bool = True): """Generate matplotlib 
visualizations for all requests' traversal paths. Args: diff --git a/geaflow-ai/plugins/casts/pyproject.toml b/geaflow-ai/plugins/casts/pyproject.toml index 43aa71d41..a6cabc047 100644 --- a/geaflow-ai/plugins/casts/pyproject.toml +++ b/geaflow-ai/plugins/casts/pyproject.toml @@ -9,8 +9,10 @@ requires-python = ">=3.10" dependencies = [ "openai>=1.86.0", "numpy>=2.0.0", - "matplotlib>=3.8.0", "networkx>=3.2.0", + "fastapi>=0.115.0", + "uvicorn>=0.30.0", + "pydantic>=2.6.0", "python-dotenv>=0.21.0", "pytest>=8.4.0", "pytest-asyncio>=0.24.0", @@ -22,9 +24,13 @@ dependencies = [ [project.optional-dependencies] dev = [ "pytest>=8.4.0", + "httpx>=0.27.0", "ruff>=0.11.13", "mypy>=1.18.1", ] +harness = [ + "matplotlib>=3.8.0", +] service = [ "flask==3.1.1", "flask-sqlalchemy==3.1.1", @@ -91,11 +97,16 @@ test = [ where = ["."] include = [ "core", "core.*", - "data", "data.*", "services", "services.*", - "simulation", "simulation.*", + "api", "api.*", "utils", "utils.*", ] -[project.scripts] -casts-sim = "simulation.runner:main" +[tool.mypy] +python_version = "3.11" +show_error_codes = true +pretty = true + +[[tool.mypy.overrides]] +module = ["matplotlib", "matplotlib.*"] +ignore_missing_imports = true diff --git a/geaflow-ai/plugins/casts/scripts/smoke.sh b/geaflow-ai/plugins/casts/scripts/smoke.sh new file mode 100755 index 000000000..9ae1ecfad --- /dev/null +++ b/geaflow-ai/plugins/casts/scripts/smoke.sh @@ -0,0 +1,161 @@ +#!/usr/bin/env bash +set -euo pipefail + +# One-click smoke test for CASTS decision service. +# +# What it does: +# 1) Ensures `.venv` exists and runs `uv sync` into it (no activation required). +# 2) Starts `uvicorn` on HOST:PORT (unless a healthy service is already running there). +# 3) Runs HTTP checks for `/health` and `POST /casts/decision`. +# 4) Shuts down the service if this script started it. 
+ +HOST="${GEAFLOW_AI_CASTS_HOST:-127.0.0.1}" +PORT="${GEAFLOW_AI_CASTS_PORT:-5001}" +TIMEOUT_SECONDS="${GEAFLOW_AI_CASTS_TIMEOUT_SECONDS:-20}" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PLUGIN_DIR="$(cd "${SCRIPT_DIR}/.." && pwd)" +LOG_DIR="${PLUGIN_DIR}/logs" +LOG_FILE="${LOG_DIR}/smoke_casts.log" + +mkdir -p "${LOG_DIR}" +cd "${PLUGIN_DIR}" + +if [[ ! -d .venv ]]; then + python3.11 -m venv .venv +fi + +VENV_PY="${PLUGIN_DIR}/.venv/bin/python" +if [[ ! -x "${VENV_PY}" ]]; then + echo "[smoke] ERROR: venv python not found at ${VENV_PY}" + exit 1 +fi + +cleanup() { + if [[ "${STARTED_HERE:-false}" == "true" && -n "${SERVER_PID:-}" ]]; then + kill "${SERVER_PID}" >/dev/null 2>&1 || true + for _ in {1..50}; do + if kill -0 "${SERVER_PID}" >/dev/null 2>&1; then + sleep 0.1 + else + break + fi + done + fi +} +trap cleanup EXIT + +echo "[smoke] Sync deps (casts venv) ..." +uv sync --extra dev + +BASE_URL="http://${HOST}:${PORT}" +STARTED_HERE="false" +SERVER_PID="" + +health_ok() { + local body + body="$(curl -fsS "${BASE_URL}/health" 2>/dev/null || true)" + if [[ -z "${body}" ]]; then + return 1 + fi + "${VENV_PY}" -c 'import json,sys; obj=json.loads(sys.argv[1]); assert obj.get("status") == "UP"' \ + "${body}" >/dev/null 2>&1 +} + +if health_ok; then + echo "[smoke] CASTS already running at ${BASE_URL}" +else + echo "[smoke] Starting CASTS at ${BASE_URL} ..." + "${VENV_PY}" -m uvicorn api.app:app --host "${HOST}" --port "${PORT}" --log-level info >"${LOG_FILE}" 2>&1 & + SERVER_PID="$!" 
+ disown "${SERVER_PID}" >/dev/null 2>&1 || true + STARTED_HERE="true" + + deadline=$(( $(date +%s) + TIMEOUT_SECONDS )) + while true; do + if health_ok; then + break + fi + if [[ $(date +%s) -ge ${deadline} ]]; then + echo "[smoke] ERROR: CASTS did not become healthy within ${TIMEOUT_SECONDS}s" + echo "[smoke] Log: ${LOG_FILE}" + tail -n 200 "${LOG_FILE}" || true + exit 1 + fi + sleep 0.2 + done +fi + +echo "[smoke] Health OK" + +post_json() { + local path="$1" + local body="$2" + local expected_code="$3" + + local tmp + tmp="$(mktemp)" + local code + code="$(curl -sS -o "${tmp}" -w "%{http_code}" -X POST "${BASE_URL}${path}" \ + -H "Content-Type: application/json" \ + -d "${body}" || true)" + local resp + resp="$(cat "${tmp}")" + rm -f "${tmp}" + + if [[ "${code}" != "${expected_code}" ]]; then + echo "[smoke] ERROR: POST ${path} expected HTTP ${expected_code}, got ${code}" + echo "[smoke] Response:" + echo "${resp}" + exit 1 + fi + + printf "%s" "${resp}" +} + +echo "[smoke] Test: empty scope rejected (expect 422)" +resp="$(post_json "/casts/decision" '{ + "api_version":"v1", + "scope":{}, + "trace":{}, + "payload":{ + "goal":"find friends", + "traversal":{"structural_signature":"V()","step_index":0}, + "node":{"label":"Person","properties":{"type":"Person","name":"Alice"}}, + "graph_schema":{ + "schema_fingerprint":"fp_smoke", + "valid_outgoing_labels":["friend"], + "valid_incoming_labels":[] + } + } +}' "422")" +printf "%s" "${resp}" | "${VENV_PY}" -c 'import json,sys; obj=json.load(sys.stdin); assert obj.get("ok") is False' +echo "[smoke] OK" + +echo "[smoke] Test: /casts/decision" +resp="$(post_json "/casts/decision" '{ + "api_version":"v1", + "scope":{"run_id":"run_smoke"}, + "trace":{}, + "payload":{ + "goal":"find friends", + "traversal":{"structural_signature":"V()","step_index":0}, + "node":{"label":"Person","properties":{"type":"Person","name":"Alice"}}, + "graph_schema":{ + "schema_fingerprint":"fp_smoke", + "valid_outgoing_labels":["friend"], + 
"valid_incoming_labels":[] + } + } +}' "200")" +printf "%s" "${resp}" | "${VENV_PY}" -c \ + 'import json,sys; obj=json.load(sys.stdin); assert obj.get("ok") is True; p=obj.get("payload") or {}; d=p.get("decision") or ""; assert isinstance(d,str) and d.strip()' +echo "[smoke] OK" + +echo +echo "[smoke] SMOKE OK: casts service" +echo "[smoke] Base URL: ${BASE_URL}" +if [[ "${STARTED_HERE}" == "true" ]]; then + echo "[smoke] Log: ${LOG_FILE}" +fi + diff --git a/geaflow-ai/plugins/casts/services/embedding.py b/geaflow-ai/plugins/casts/services/embedding.py index fd7a95746..a6eec928f 100644 --- a/geaflow-ai/plugins/casts/services/embedding.py +++ b/geaflow-ai/plugins/casts/services/embedding.py @@ -34,7 +34,7 @@ class EmbeddingService: def __init__(self, config: Configuration): """Initialize embedding service with configuration. - + Args: config: Configuration object containing API settings """ @@ -58,9 +58,7 @@ def __init__(self, config: Configuration): if not model: missing.append("EMBEDDING_MODEL_NAME") if missing: - raise ValueError( - "Missing required embedding configuration: " + ", ".join(missing) - ) + raise ValueError("Missing required embedding configuration: " + ", ".join(missing)) self.client = AsyncOpenAI(api_key=api_key, base_url=endpoint) self.model = model @@ -69,10 +67,10 @@ def __init__(self, config: Configuration): async def embed_text(self, text: str) -> np.ndarray: """ Generate embedding vector for a text string. - + Args: text: Input text to embed - + Returns: Normalized numpy array of embedding vector """ @@ -85,10 +83,10 @@ async def embed_text(self, text: str) -> np.ndarray: async def embed_properties(self, properties: JsonDict) -> np.ndarray: """ Generate embedding vector for a dictionary of properties. 
- + Args: properties: Property dictionary (identity fields will be filtered out) - + Returns: Normalized numpy array of embedding vector """ diff --git a/geaflow-ai/plugins/casts/services/llm_oracle.py b/geaflow-ai/plugins/casts/services/llm_oracle.py index 200d85d3a..e37a6da8b 100644 --- a/geaflow-ai/plugins/casts/services/llm_oracle.py +++ b/geaflow-ai/plugins/casts/services/llm_oracle.py @@ -26,17 +26,16 @@ from core.config import DefaultConfiguration from core.gremlin_state import GremlinStateMachine -from core.interfaces import Configuration, GraphSchema +from core.interfaces import Configuration, EmbeddingServiceProtocol, GraphSchema from core.models import Context, StrategyKnowledgeUnit from core.types import JsonDict -from services.embedding import EmbeddingService from utils.helpers import parse_jsons class LLMOracle: """Real LLM Oracle using OpenRouter API for generating traversal strategies.""" - def __init__(self, embed_service: EmbeddingService, config: Configuration): + def __init__(self, embed_service: EmbeddingServiceProtocol, config: Configuration): """Initialize LLM Oracle with configuration. 
Args: @@ -171,7 +170,7 @@ async def generate_sku(self, context: Context, schema: GraphSchema) -> StrategyK predicate=lambda x: True, goal_template=context.goal, decision_template="stop", - schema_fingerprint="schema_v1", + schema_fingerprint=self.config.get_str("CACHE_SCHEMA_FINGERPRINT", "schema_v1"), property_vector=property_vector, confidence_score=1.0, logic_complexity=1, @@ -207,12 +206,8 @@ def _format_list(values: list[str], max_items: int = 12) -> str: node_type = str(safe_properties.get("type") or context.properties.get("type") or "") node_schema = schema.get_node_schema(node_type) if node_type else {} - outgoing_labels = ( - schema.get_valid_outgoing_edge_labels(node_type) if node_type else [] - ) - incoming_labels = ( - schema.get_valid_incoming_edge_labels(node_type) if node_type else [] - ) + outgoing_labels = schema.get_valid_outgoing_edge_labels(node_type) if node_type else [] + incoming_labels = schema.get_valid_incoming_edge_labels(node_type) if node_type else [] max_depth = self.config.get_int("SIMULATION_MAX_DEPTH") current_depth = len( @@ -348,14 +343,22 @@ def _format_list(values: list[str], max_items: int = 12) -> str: def _default_predicate(_: JsonDict) -> bool: return True + predicate = _default_predicate + allow_eval = False try: - predicate_code = result.get("predicate", "lambda x: True") - predicate = eval(predicate_code) - if not callable(predicate): - predicate = _default_predicate - _ = predicate(safe_properties) # Test call + allow_eval = self.config.get_bool("LLM_ORACLE_ENABLE_PREDICATE_EVAL", False) except Exception: - predicate = _default_predicate + allow_eval = False + + if allow_eval: + try: + predicate_code = result.get("predicate", "lambda x: True") + predicate_candidate = eval(predicate_code, {"__builtins__": {}}, {}) + if callable(predicate_candidate): + _ = predicate_candidate(safe_properties) # Test call + predicate = predicate_candidate + except Exception: + predicate = _default_predicate property_vector = await 
self.embed_service.embed_properties(safe_properties) sigma_val = result.get("sigma_logic", 1) @@ -369,7 +372,7 @@ def _default_predicate(_: JsonDict) -> bool: goal_template=context.goal, property_vector=property_vector, decision_template=decision, - schema_fingerprint="schema_v1", + schema_fingerprint=self.config.get_str("CACHE_SCHEMA_FINGERPRINT", "schema_v1"), confidence_score=1.0, # Start with high confidence logic_complexity=sigma_val, ) @@ -391,7 +394,7 @@ def _default_predicate(_: JsonDict) -> bool: predicate=lambda x: True, goal_template=context.goal, decision_template="stop", - schema_fingerprint="schema_v1", + schema_fingerprint=self.config.get_str("CACHE_SCHEMA_FINGERPRINT", "schema_v1"), property_vector=property_vector, confidence_score=1.0, logic_complexity=1, @@ -458,9 +461,7 @@ async def recommend_starting_node_types( ) if not self.client: - self._write_debug( - "LLM client not available, falling back to all node types" - ) + self._write_debug("LLM client not available, falling back to all node types") # Fallback: return all types if LLM unavailable return node_types_list[:max_recommendations] @@ -494,8 +495,7 @@ async def recommend_starting_node_types( if isinstance(result, list): # Filter to only valid node types and limit to max recommended = [ - nt for nt in result - if isinstance(nt, str) and nt in available_node_types + nt for nt in result if isinstance(nt, str) and nt in available_node_types ][:max_recommendations] self._write_debug( diff --git a/geaflow-ai/plugins/casts/tests/conftest.py b/geaflow-ai/plugins/casts/tests/conftest.py index d484f139e..7995c4426 100644 --- a/geaflow-ai/plugins/casts/tests/conftest.py +++ b/geaflow-ai/plugins/casts/tests/conftest.py @@ -30,6 +30,12 @@ if str(module_root_parent) not in sys.path: sys.path.insert(0, str(module_root_parent)) +# Ensure the CASTS plugin root is importable so tests can import `core.*` and +# `harness.*` directly (pytest sets rootdir to the plugin directory). 
+plugin_root = Path(__file__).resolve().parents[1] +if str(plugin_root) not in sys.path: + sys.path.insert(0, str(plugin_root)) + def _ensure_env() -> None: os.environ.setdefault("EMBEDDING_ENDPOINT", "http://localhost") diff --git a/geaflow-ai/plugins/casts/tests/test_api.py b/geaflow-ai/plugins/casts/tests/test_api.py new file mode 100644 index 000000000..1f4e68df0 --- /dev/null +++ b/geaflow-ai/plugins/casts/tests/test_api.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from fastapi.testclient import TestClient + +from api.app import app + + +def _base_payload() -> dict: + return { + "goal": "find friends", + "traversal": {"structural_signature": "V()", "step_index": 0}, + "node": {"label": "Person", "properties": {"type": "Person", "name": "Alice"}}, + "graph_schema": { + "schema_fingerprint": "fp_test", + "valid_outgoing_labels": ["friend"], + "valid_incoming_labels": [], + }, + } + + +def test_health_ok() -> None: + client = TestClient(app) + res = client.get("/health") + assert res.status_code == 200 + assert res.json() == {"status": "UP"} + + +def test_empty_scope_rejected() -> None: + client = TestClient(app) + res = client.post( + "/casts/decision", + json={"api_version": "v1", "scope": {}, "trace": {}, "payload": _base_payload()}, + ) + assert res.status_code == 422 + body = res.json() + assert body["ok"] is False + + +def test_missing_run_id_downgrades_to_stop() -> None: + client = TestClient(app) + res = client.post( + "/casts/decision", + json={ + "api_version": "v1", + "scope": {"user_id": "u1"}, + "trace": {}, + "payload": _base_payload(), + }, + ) + assert res.status_code == 200 + body = res.json() + assert body["ok"] is True + assert body["payload"]["decision"] == "stop" + assert body["payload"]["match_type"] == "STOP_INVALID" + + +def test_with_run_id_returns_a_decision() -> None: + client = TestClient(app) + res = client.post( + "/casts/decision", + json={ + "api_version": "v1", + "scope": {"run_id": "run_test"}, + "trace": {}, + "payload": _base_payload(), + }, + ) + assert res.status_code == 200 + body = res.json() + assert body["ok"] is True + decision = body["payload"]["decision"] + assert isinstance(decision, str) + assert decision.strip() diff --git a/geaflow-ai/plugins/casts/tests/test_execution_lifecycle.py b/geaflow-ai/plugins/casts/tests/test_execution_lifecycle.py index 3d19e962a..4b5945e8e 100644 --- a/geaflow-ai/plugins/casts/tests/test_execution_lifecycle.py 
+++ b/geaflow-ai/plugins/casts/tests/test_execution_lifecycle.py @@ -20,8 +20,8 @@ from unittest.mock import Mock from core.config import DefaultConfiguration -from simulation.engine import SimulationEngine -from simulation.metrics import MetricsCollector +from harness.simulation.engine import SimulationEngine +from harness.simulation.metrics import MetricsCollector class MockSKU: @@ -45,10 +45,7 @@ def setup_method(self): self.mock_graph.get_schema.return_value = Mock() self.engine = SimulationEngine( - graph=self.mock_graph, - strategy_cache=Mock(), - llm_oracle=self.llm_oracle, - verbose=False + graph=self.mock_graph, strategy_cache=Mock(), llm_oracle=self.llm_oracle, verbose=False ) def test_none_mode_skips_all_validation(self): @@ -75,9 +72,7 @@ def test_none_mode_skips_all_validation(self): ) sku = MockSKU(confidence_score=0.5) - should_execute, success = self.engine.execute_prechecker( - sku, request_id, metrics - ) + should_execute, success = self.engine.execute_prechecker(sku, request_id, metrics) # Should always return (True, True) in NONE mode assert should_execute is True @@ -109,9 +104,7 @@ def test_punish_mode_continues_with_penalty(self): ) sku = MockSKU(confidence_score=0.5) - should_execute, success = self.engine.execute_prechecker( - sku, request_id, metrics - ) + should_execute, success = self.engine.execute_prechecker(sku, request_id, metrics) # Should continue but signal failure for penalty assert should_execute is True @@ -143,9 +136,7 @@ def test_stop_mode_terminates_path(self): ) sku = MockSKU(confidence_score=0.5) - should_execute, success = self.engine.execute_prechecker( - sku, request_id, metrics - ) + should_execute, success = self.engine.execute_prechecker(sku, request_id, metrics) # Should terminate and signal failure assert should_execute is False @@ -176,9 +167,7 @@ def test_low_revisit_ratio_passes(self): ) sku = MockSKU(confidence_score=0.5) - should_execute, success = self.engine.execute_prechecker( - sku, request_id, metrics - ) 
+ should_execute, success = self.engine.execute_prechecker(sku, request_id, metrics) # Should pass all checks (0% revisit < 50% threshold) assert should_execute is True @@ -208,9 +197,7 @@ def test_simple_path_skips_cycle_detection(self): ) sku = MockSKU(confidence_score=0.5) - should_execute, success = self.engine.execute_prechecker( - sku, request_id, metrics - ) + should_execute, success = self.engine.execute_prechecker(sku, request_id, metrics) assert should_execute is True assert success is True @@ -240,9 +227,7 @@ def test_confidence_threshold_stop_mode(self): # SKU with confidence below threshold sku = MockSKU(confidence_score=0.1) - should_execute, success = self.engine.execute_prechecker( - sku, request_id, metrics - ) + should_execute, success = self.engine.execute_prechecker(sku, request_id, metrics) # Should terminate due to low confidence assert should_execute is False @@ -273,9 +258,7 @@ def test_confidence_threshold_punish_mode(self): # SKU with confidence below threshold sku = MockSKU(confidence_score=0.1) - should_execute, success = self.engine.execute_prechecker( - sku, request_id, metrics - ) + should_execute, success = self.engine.execute_prechecker(sku, request_id, metrics) # Should continue but penalize assert should_execute is True @@ -287,9 +270,7 @@ def test_no_sku_passes_validation(self): metrics = MetricsCollector() request_id = metrics.initialize_path(0, "node1", {}, "goal", "rubric") - should_execute, success = self.engine.execute_prechecker( - None, request_id, metrics - ) + should_execute, success = self.engine.execute_prechecker(None, request_id, metrics) # None SKU should always pass assert should_execute is True @@ -302,7 +283,9 @@ def test_nonexistent_request_id_passes(self): sku = MockSKU(confidence_score=0.5) should_execute, success = self.engine.execute_prechecker( - sku, 999, metrics # Non-existent request ID + sku, + 999, + metrics, # Non-existent request ID ) # Should pass since path doesn't exist @@ -347,9 +330,7 @@ def 
test_cycle_detection_threshold_boundary(self): ) sku = MockSKU(confidence_score=0.5) - should_execute, success = self.engine.execute_prechecker( - sku, request_id, metrics - ) + should_execute, success = self.engine.execute_prechecker(sku, request_id, metrics) # Should pass at exactly threshold (not greater than) assert should_execute is True @@ -382,9 +363,7 @@ def test_cycle_detection_just_above_threshold(self): ) sku = MockSKU(confidence_score=0.5) - should_execute, success = self.engine.execute_prechecker( - sku, request_id, metrics - ) + should_execute, success = self.engine.execute_prechecker(sku, request_id, metrics) # Should fail cycle detection assert should_execute is False @@ -405,10 +384,7 @@ def setup_method(self): self.mock_graph.get_schema.return_value = Mock() self.engine = SimulationEngine( - graph=self.mock_graph, - strategy_cache=Mock(), - llm_oracle=self.llm_oracle, - verbose=False + graph=self.mock_graph, strategy_cache=Mock(), llm_oracle=self.llm_oracle, verbose=False ) def test_postchecker_always_returns_true(self): @@ -418,9 +394,7 @@ def test_postchecker_always_returns_true(self): sku = MockSKU() execution_result = ["node2", "node3"] - result = self.engine.execute_postchecker( - sku, request_id, metrics, execution_result - ) + result = self.engine.execute_postchecker(sku, request_id, metrics, execution_result) assert result is True @@ -430,9 +404,7 @@ def test_postchecker_with_none_sku(self): request_id = metrics.initialize_path(0, "node1", {}, "goal", "rubric") execution_result = [] - result = self.engine.execute_postchecker( - None, request_id, metrics, execution_result - ) + result = self.engine.execute_postchecker(None, request_id, metrics, execution_result) assert result is True @@ -442,9 +414,7 @@ def test_postchecker_with_empty_result(self): request_id = metrics.initialize_path(0, "node1", {}, "goal", "rubric") sku = MockSKU() - result = self.engine.execute_postchecker( - sku, request_id, metrics, [] - ) + result = 
self.engine.execute_postchecker(sku, request_id, metrics, []) assert result is True @@ -463,10 +433,7 @@ def setup_method(self): self.mock_graph.get_schema.return_value = Mock() self.engine = SimulationEngine( - graph=self.mock_graph, - strategy_cache=Mock(), - llm_oracle=self.llm_oracle, - verbose=False + graph=self.mock_graph, strategy_cache=Mock(), llm_oracle=self.llm_oracle, verbose=False ) def test_mode_none_case_insensitive(self): @@ -478,14 +445,22 @@ def test_mode_none_case_insensitive(self): # Add cyclic steps for i in range(5): metrics.record_path_step( - request_id, i, "node1", None, None, None, f"sig{i}", - "goal", {}, "Tier1", f"sku{i}", f"d{i}" + request_id, + i, + "node1", + None, + None, + None, + f"sig{i}", + "goal", + {}, + "Tier1", + f"sku{i}", + f"d{i}", ) sku = MockSKU(confidence_score=0.5) - should_execute, success = self.engine.execute_prechecker( - sku, request_id, metrics - ) + should_execute, success = self.engine.execute_prechecker(sku, request_id, metrics) # NONE mode should skip validation even with lowercase assert should_execute is True @@ -519,9 +494,7 @@ def test_mode_punish_case_variants(self): ) sku = MockSKU(confidence_score=0.5) - should_execute, success = self.engine.execute_prechecker( - sku, request_id, metrics - ) + should_execute, success = self.engine.execute_prechecker(sku, request_id, metrics) # All variants should work consistently assert should_execute is True @@ -542,10 +515,7 @@ def setup_method(self): self.mock_graph.get_schema.return_value = Mock() self.engine = SimulationEngine( - graph=self.mock_graph, - strategy_cache=Mock(), - llm_oracle=self.llm_oracle, - verbose=False + graph=self.mock_graph, strategy_cache=Mock(), llm_oracle=self.llm_oracle, verbose=False ) def test_cycle_detection_threshold_default(self): @@ -588,9 +558,7 @@ def test_custom_threshold_values(self): ) sku = MockSKU(confidence_score=0.6) # Above 0.5 min confidence - should_execute, success = self.engine.execute_prechecker( - sku, request_id, 
metrics - ) + should_execute, success = self.engine.execute_prechecker(sku, request_id, metrics) # Should fail cycle detection but pass confidence check assert should_execute is True # PUNISH mode continues diff --git a/geaflow-ai/plugins/casts/tests/test_gremlin_step_state_machine.py b/geaflow-ai/plugins/casts/tests/test_gremlin_step_state_machine.py index 64f39904a..54210543c 100644 --- a/geaflow-ai/plugins/casts/tests/test_gremlin_step_state_machine.py +++ b/geaflow-ai/plugins/casts/tests/test_gremlin_step_state_machine.py @@ -70,6 +70,7 @@ - **Idea**: Enforce strict syntax. - **Verify**: `V().outV()` must lead to `END` with no options. """ + import unittest from core.gremlin_state import GremlinStateMachine @@ -82,22 +83,22 @@ class TestGraphSchema(unittest.TestCase): def setUp(self): """Set up a mock graph schema for testing.""" nodes = { - 'A': {'id': 'A', 'type': 'Person'}, - 'B': {'id': 'B', 'type': 'Person'}, - 'C': {'id': 'C', 'type': 'Company'}, - 'D': {'id': 'D', 'type': 'Person'}, # Node with only incoming edges + "A": {"id": "A", "type": "Person"}, + "B": {"id": "B", "type": "Person"}, + "C": {"id": "C", "type": "Company"}, + "D": {"id": "D", "type": "Person"}, # Node with only incoming edges } edges = { - 'A': [ - {'label': 'friend', 'target': 'B'}, - {'label': 'works_for', 'target': 'C'}, + "A": [ + {"label": "friend", "target": "B"}, + {"label": "works_for", "target": "C"}, ], - 'B': [ - {'label': 'friend', 'target': 'A'}, + "B": [ + {"label": "friend", "target": "A"}, ], - 'C': [ - {'label': 'employs', 'target': 'A'}, - {'label': 'partner', 'target': 'D'}, + "C": [ + {"label": "employs", "target": "A"}, + {"label": "partner", "target": "D"}, ], } self.schema = InMemoryGraphSchema(nodes, edges) @@ -105,59 +106,54 @@ def setUp(self): def test_get_valid_outgoing_edge_labels(self): """Test that get_valid_outgoing_edge_labels returns correct outgoing labels.""" self.assertCountEqual( - self.schema.get_valid_outgoing_edge_labels('Person'), ['friend', 
'works_for'] + self.schema.get_valid_outgoing_edge_labels("Person"), ["friend", "works_for"] ) self.assertCountEqual( - self.schema.get_valid_outgoing_edge_labels('Company'), ['employs', 'partner'] + self.schema.get_valid_outgoing_edge_labels("Company"), ["employs", "partner"] ) def test_get_valid_outgoing_edge_labels_no_outgoing(self): """Test get_valid_outgoing_edge_labels returns empty list with no outgoing edges.""" - self.assertEqual(self.schema.get_valid_outgoing_edge_labels('Unknown'), []) + self.assertEqual(self.schema.get_valid_outgoing_edge_labels("Unknown"), []) def test_get_valid_incoming_edge_labels(self): """Test that get_valid_incoming_edge_labels returns correct incoming labels.""" self.assertCountEqual( - self.schema.get_valid_incoming_edge_labels('Person'), - ['employs', 'friend', 'partner'], - ) - self.assertCountEqual( - self.schema.get_valid_incoming_edge_labels('Company'), ['works_for'] + self.schema.get_valid_incoming_edge_labels("Person"), + ["employs", "friend", "partner"], ) + self.assertCountEqual(self.schema.get_valid_incoming_edge_labels("Company"), ["works_for"]) def test_get_valid_incoming_edge_labels_no_incoming(self): """Test get_valid_incoming_edge_labels returns empty list with no incoming edges.""" - self.assertEqual(self.schema.get_valid_incoming_edge_labels('Unknown'), []) + self.assertEqual(self.schema.get_valid_incoming_edge_labels("Unknown"), []) class TestGremlinStateMachine(unittest.TestCase): - def setUp(self): """Set up a mock graph schema for testing the state machine.""" nodes = { - 'A': {'id': 'A', 'type': 'Person'}, - 'B': {'id': 'B', 'type': 'Person'}, - 'C': {'id': 'C', 'type': 'Company'}, + "A": {"id": "A", "type": "Person"}, + "B": {"id": "B", "type": "Person"}, + "C": {"id": "C", "type": "Company"}, } edges = { - 'A': [ - {'label': 'friend', 'target': 'B'}, - {'label': 'knows', 'target': 'B'}, + "A": [ + {"label": "friend", "target": "B"}, + {"label": "knows", "target": "B"}, ], - 'B': [ - {'label': 'friend', 
'target': 'A'}, + "B": [ + {"label": "friend", "target": "A"}, ], - 'C': [ - {'label': 'employs', 'target': 'A'}, + "C": [ + {"label": "employs", "target": "A"}, ], } self.schema = InMemoryGraphSchema(nodes, edges) def test_vertex_state_options(self): """Test that the state machine generates correct, concrete options from a vertex state.""" - state, options = GremlinStateMachine.get_state_and_options( - "V()", self.schema, "Person" - ) + state, options = GremlinStateMachine.get_state_and_options("V()", self.schema, "Person") self.assertEqual(state, "V") # Check for concrete 'out' steps @@ -178,9 +174,7 @@ def test_vertex_state_options(self): def test_empty_labels(self): """Test that no label-based steps are generated if there are no corresponding edges.""" - state, options = GremlinStateMachine.get_state_and_options( - "V()", self.schema, "Company" - ) + state, options = GremlinStateMachine.get_state_and_options("V()", self.schema, "Company") self.assertEqual(state, "V") # Company has outgoing 'employs'/'partner' edges but no incoming edges in this setup. 
self.assertIn("out('employs')", options) diff --git a/geaflow-ai/plugins/casts/tests/test_lifecycle_integration.py b/geaflow-ai/plugins/casts/tests/test_lifecycle_integration.py index 3c74a6dd1..91d9b009c 100644 --- a/geaflow-ai/plugins/casts/tests/test_lifecycle_integration.py +++ b/geaflow-ai/plugins/casts/tests/test_lifecycle_integration.py @@ -20,8 +20,8 @@ from unittest.mock import Mock from core.config import DefaultConfiguration -from simulation.engine import SimulationEngine -from simulation.metrics import MetricsCollector +from harness.simulation.engine import SimulationEngine +from harness.simulation.metrics import MetricsCollector class MockSKU: @@ -41,10 +41,7 @@ def __init__(self): def update_confidence(self, sku, success): """Record confidence updates.""" - self.confidence_updates.append({ - "sku": sku, - "success": success - }) + self.confidence_updates.append({"sku": sku, "success": success}) class TestLifecycleIntegration: @@ -65,7 +62,7 @@ def setup_method(self): graph=self.mock_graph, strategy_cache=self.strategy_cache, llm_oracle=self.llm_oracle, - verbose=False + verbose=False, ) def test_complete_lifecycle_with_passing_precheck(self): @@ -77,16 +74,24 @@ def test_complete_lifecycle_with_passing_precheck(self): # Add a step with low revisit metrics.record_path_step( - request_id, 0, "node1", None, None, None, "sig1", "goal", {}, - "Tier1", "sku1", "out('friend')" + request_id, + 0, + "node1", + None, + None, + None, + "sig1", + "goal", + {}, + "Tier1", + "sku1", + "out('friend')", ) sku = MockSKU(confidence_score=0.5) # Phase 1: Precheck - should_execute, precheck_success = self.engine.execute_prechecker( - sku, request_id, metrics - ) + should_execute, precheck_success = self.engine.execute_prechecker(sku, request_id, metrics) assert should_execute is True assert precheck_success is True @@ -114,16 +119,24 @@ def test_complete_lifecycle_with_failing_precheck_stop_mode(self): # Create high revisit ratio for i in range(10): 
metrics.record_path_step( - request_id, i, "node1", None, None, None, f"sig{i}", - "goal", {}, "Tier1", f"sku{i}", "out('friend')" + request_id, + i, + "node1", + None, + None, + None, + f"sig{i}", + "goal", + {}, + "Tier1", + f"sku{i}", + "out('friend')", ) sku = MockSKU(confidence_score=0.5) # Phase 1: Precheck - should_execute, precheck_success = self.engine.execute_prechecker( - sku, request_id, metrics - ) + should_execute, precheck_success = self.engine.execute_prechecker(sku, request_id, metrics) assert should_execute is False assert precheck_success is False @@ -140,16 +153,24 @@ def test_complete_lifecycle_with_failing_precheck_punish_mode(self): # Create high revisit ratio for i in range(10): metrics.record_path_step( - request_id, i, "node1", None, None, None, f"sig{i}", - "goal", {}, "Tier1", f"sku{i}", "out('friend')" + request_id, + i, + "node1", + None, + None, + None, + f"sig{i}", + "goal", + {}, + "Tier1", + f"sku{i}", + "out('friend')", ) sku = MockSKU(confidence_score=0.5) # Phase 1: Precheck - should_execute, precheck_success = self.engine.execute_prechecker( - sku, request_id, metrics - ) + should_execute, precheck_success = self.engine.execute_prechecker(sku, request_id, metrics) assert should_execute is True # Continue execution assert precheck_success is False # But signal failure @@ -174,8 +195,18 @@ def test_rollback_integration_with_precheck_failure(self): # Add steps leading to cycle for i in range(10): metrics.record_path_step( - request_id, i, "node1", None, None, None, f"sig{i}", - "goal", {}, "Tier1", f"sku{i}", "out('friend')" + request_id, + i, + "node1", + None, + None, + None, + f"sig{i}", + "goal", + {}, + "Tier1", + f"sku{i}", + "out('friend')", ) initial_step_count = len(metrics.paths[request_id]["steps"]) @@ -184,9 +215,7 @@ def test_rollback_integration_with_precheck_failure(self): sku = MockSKU(confidence_score=0.5) # Precheck fails - should_execute, _ = self.engine.execute_prechecker( - sku, request_id, metrics - ) + 
should_execute, _ = self.engine.execute_prechecker(sku, request_id, metrics) if not should_execute: # Simulate rollback as done in real code @@ -202,9 +231,7 @@ def test_lifecycle_with_none_sku(self): request_id = metrics.initialize_path(0, "node1", {}, "goal", "rubric") # Phase 1: Precheck with None SKU - should_execute, precheck_success = self.engine.execute_prechecker( - None, request_id, metrics - ) + should_execute, precheck_success = self.engine.execute_prechecker(None, request_id, metrics) assert should_execute is True assert precheck_success is True @@ -228,16 +255,24 @@ def test_lifecycle_confidence_penalty_integration(self): # Add cyclic steps for i in range(5): metrics.record_path_step( - request_id, i, "node1", None, None, None, f"sig{i}", - "goal", {}, "Tier1", f"sku{i}", "out('friend')" + request_id, + i, + "node1", + None, + None, + None, + f"sig{i}", + "goal", + {}, + "Tier1", + f"sku{i}", + "out('friend')", ) sku = MockSKU(confidence_score=0.5) # Precheck fails due to cycle - should_execute, precheck_success = self.engine.execute_prechecker( - sku, request_id, metrics - ) + should_execute, precheck_success = self.engine.execute_prechecker(sku, request_id, metrics) # Should continue but penalize assert should_execute is True @@ -261,16 +296,24 @@ def test_lifecycle_multiple_validation_failures(self): # Create both cycle and low confidence for i in range(10): metrics.record_path_step( - request_id, i, "node1", None, None, None, f"sig{i}", - "goal", {}, "Tier1", f"sku{i}", "out('friend')" + request_id, + i, + "node1", + None, + None, + None, + f"sig{i}", + "goal", + {}, + "Tier1", + f"sku{i}", + "out('friend')", ) sku = MockSKU(confidence_score=0.2) # Below threshold # Precheck should fail on first condition met - should_execute, precheck_success = self.engine.execute_prechecker( - sku, request_id, metrics - ) + should_execute, precheck_success = self.engine.execute_prechecker(sku, request_id, metrics) # Should terminate (STOP mode) assert 
should_execute is False @@ -285,16 +328,24 @@ def test_lifecycle_none_mode_bypasses_all_checks(self): # Create worst-case scenario: high cycles + low confidence for i in range(20): metrics.record_path_step( - request_id, i, "node1", None, None, None, f"sig{i}", - "goal", {}, "Tier1", f"sku{i}", "out('friend')" + request_id, + i, + "node1", + None, + None, + None, + f"sig{i}", + "goal", + {}, + "Tier1", + f"sku{i}", + "out('friend')", ) sku = MockSKU(confidence_score=0.01) # Extremely low # Precheck should still pass in NONE mode - should_execute, precheck_success = self.engine.execute_prechecker( - sku, request_id, metrics - ) + should_execute, precheck_success = self.engine.execute_prechecker(sku, request_id, metrics) assert should_execute is True assert precheck_success is True @@ -308,9 +359,7 @@ def test_lifecycle_with_empty_path(self): sku = MockSKU(confidence_score=0.5) # Precheck on empty path - should_execute, precheck_success = self.engine.execute_prechecker( - sku, request_id, metrics - ) + should_execute, precheck_success = self.engine.execute_prechecker(sku, request_id, metrics) # Should pass (no cycle possible with empty path) assert should_execute is True @@ -326,22 +375,28 @@ def test_lifecycle_preserves_path_state(self): # Add steps for i in range(5): metrics.record_path_step( - request_id, i, f"node{i}", None, None, None, f"sig{i}", - "goal", {}, "Tier1", f"sku{i}", "out('friend')" + request_id, + i, + f"node{i}", + None, + None, + None, + f"sig{i}", + "goal", + {}, + "Tier1", + f"sku{i}", + "out('friend')", ) - initial_steps = [ - step.copy() for step in metrics.paths[request_id]["steps"] - ] + initial_steps = [step.copy() for step in metrics.paths[request_id]["steps"]] sku = MockSKU(confidence_score=0.5) # Run precheck self.engine.execute_prechecker(sku, request_id, metrics) # Run postcheck - self.engine.execute_postchecker( - sku, request_id, metrics, ["node6"] - ) + self.engine.execute_postchecker(sku, request_id, metrics, ["node6"]) # Verify 
path state unchanged assert len(metrics.paths[request_id]["steps"]) == len(initial_steps) @@ -363,10 +418,7 @@ def setup_method(self): self.mock_graph.get_schema.return_value = Mock() self.engine = SimulationEngine( - graph=self.mock_graph, - strategy_cache=Mock(), - llm_oracle=self.llm_oracle, - verbose=False + graph=self.mock_graph, strategy_cache=Mock(), llm_oracle=self.llm_oracle, verbose=False ) def test_lifecycle_with_single_step_path(self): @@ -378,14 +430,22 @@ def test_lifecycle_with_single_step_path(self): # Single step - cannot have cycle metrics.record_path_step( - request_id, 0, "node1", None, None, None, "sig1", "goal", {}, - "Tier1", "sku1", "out('friend')" + request_id, + 0, + "node1", + None, + None, + None, + "sig1", + "goal", + {}, + "Tier1", + "sku1", + "out('friend')", ) sku = MockSKU(confidence_score=0.5) - should_execute, success = self.engine.execute_prechecker( - sku, request_id, metrics - ) + should_execute, success = self.engine.execute_prechecker(sku, request_id, metrics) # Single step should pass (cycle detection requires >= 2 steps) assert should_execute is True @@ -403,26 +463,42 @@ def test_lifecycle_alternating_pass_fail(self): # Start with low revisit (pass) for i in range(3): metrics.record_path_step( - request_id, i, f"node{i}", None, None, None, f"sig{i}", - "goal", {}, "Tier1", f"sku{i}", "out('friend')" + request_id, + i, + f"node{i}", + None, + None, + None, + f"sig{i}", + "goal", + {}, + "Tier1", + f"sku{i}", + "out('friend')", ) sku = MockSKU(confidence_score=0.5) - should_execute, success = self.engine.execute_prechecker( - sku, request_id, metrics - ) + should_execute, success = self.engine.execute_prechecker(sku, request_id, metrics) results.append(("pass", should_execute, success)) # Add cycles (fail) - all same node for i in range(7): metrics.record_path_step( - request_id, 3 + i, "node1", None, None, None, f"sig{3+i}", - "goal", {}, "Tier1", f"sku{3+i}", "out('friend')" + request_id, + 3 + i, + "node1", + None, + 
None, + None, + f"sig{3 + i}", + "goal", + {}, + "Tier1", + f"sku{3 + i}", + "out('friend')", ) - should_execute, success = self.engine.execute_prechecker( - sku, request_id, metrics - ) + should_execute, success = self.engine.execute_prechecker(sku, request_id, metrics) results.append(("fail", should_execute, success)) # Verify pattern: first passes (0% revisit), second fails (high revisit) @@ -437,14 +513,22 @@ def test_lifecycle_with_zero_confidence(self): request_id = metrics.initialize_path(0, "node1", {}, "goal", "rubric") metrics.record_path_step( - request_id, 0, "node1", None, None, None, "sig", "goal", {}, - "Tier1", "sku1", "out('friend')" + request_id, + 0, + "node1", + None, + None, + None, + "sig", + "goal", + {}, + "Tier1", + "sku1", + "out('friend')", ) sku = MockSKU(confidence_score=0.0) - should_execute, success = self.engine.execute_prechecker( - sku, request_id, metrics - ) + should_execute, success = self.engine.execute_prechecker(sku, request_id, metrics) # Should fail due to confidence < 0.1 assert should_execute is False @@ -458,14 +542,22 @@ def test_lifecycle_with_perfect_confidence(self): request_id = metrics.initialize_path(0, "node1", {}, "goal", "rubric") metrics.record_path_step( - request_id, 0, "node1", None, None, None, "sig", "goal", {}, - "Tier1", "sku1", "out('friend')" + request_id, + 0, + "node1", + None, + None, + None, + "sig", + "goal", + {}, + "Tier1", + "sku1", + "out('friend')", ) sku = MockSKU(confidence_score=1.0) - should_execute, success = self.engine.execute_prechecker( - sku, request_id, metrics - ) + should_execute, success = self.engine.execute_prechecker(sku, request_id, metrics) # Should pass all checks assert should_execute is True diff --git a/geaflow-ai/plugins/casts/tests/test_metrics_collector.py b/geaflow-ai/plugins/casts/tests/test_metrics_collector.py index 0f8f5a553..199fad8ee 100644 --- a/geaflow-ai/plugins/casts/tests/test_metrics_collector.py +++ 
b/geaflow-ai/plugins/casts/tests/test_metrics_collector.py @@ -17,7 +17,7 @@ """Unit tests for MetricsCollector class.""" -from simulation.metrics import MetricsCollector +from harness.simulation.metrics import MetricsCollector class TestMetricsCollector: @@ -53,7 +53,7 @@ def test_record_path_step(self): properties={"name": "Alice"}, match_type="Tier1", sku_id="sku1", - decision="out('knows')" + decision="out('knows')", ) steps = metrics.paths[request_id]["steps"] @@ -152,8 +152,18 @@ def test_rollback_multiple_times(self): # Add 5 steps for i in range(5): metrics.record_path_step( - request_id, i, f"node{i}", None, None, None, f"sig{i}", - "goal", {}, "Tier1", f"sku{i}", f"d{i}" + request_id, + i, + f"node{i}", + None, + None, + None, + f"sig{i}", + "goal", + {}, + "Tier1", + f"sku{i}", + f"d{i}", ) assert len(metrics.paths[request_id]["steps"]) == 5 diff --git a/geaflow-ai/plugins/casts/tests/test_signature_abstraction.py b/geaflow-ai/plugins/casts/tests/test_signature_abstraction.py index 17d4c5da5..88c56f7d5 100644 --- a/geaflow-ai/plugins/casts/tests/test_signature_abstraction.py +++ b/geaflow-ai/plugins/casts/tests/test_signature_abstraction.py @@ -37,7 +37,7 @@ from core.interfaces import DataSource, GraphSchema from core.models import Context, StrategyKnowledgeUnit from core.strategy_cache import StrategyCache -from simulation.executor import TraversalExecutor +from harness.simulation.executor import TraversalExecutor class MockGraphSchema(GraphSchema): @@ -135,9 +135,7 @@ async def test_edge_traversal_preserves_labels(self): decision = "out('friend')" current_node_id = "A" - result = await self.executor.execute_decision( - current_node_id, decision, current_signature - ) + result = await self.executor.execute_decision(current_node_id, decision, current_signature) # Verify edge labels are preserved in the signature. 
self.assertEqual(len(result), 1) @@ -151,9 +149,7 @@ async def test_filter_step_preserves_full_details(self): decision = "has('type','Person')" current_node_id = "A" - result = await self.executor.execute_decision( - current_node_id, decision, current_signature - ) + result = await self.executor.execute_decision(current_node_id, decision, current_signature) # Verify has() parameters are preserved. if result: # has() may not match and can return an empty list. @@ -166,9 +162,7 @@ async def test_edge_step_with_outE(self): decision = "outE('transfer')" current_node_id = "B" - result = await self.executor.execute_decision( - current_node_id, decision, current_signature - ) + result = await self.executor.execute_decision(current_node_id, decision, current_signature) self.assertEqual(len(result), 1) next_node_id, next_signature, traversed_edge = result[0] @@ -180,9 +174,7 @@ async def test_dedup_step_canonical_form(self): decision = "dedup()" current_node_id = "A" - result = await self.executor.execute_decision( - current_node_id, decision, current_signature - ) + result = await self.executor.execute_decision(current_node_id, decision, current_signature) # dedup should be retained in the signature. 
self.assertEqual(len(result), 1) @@ -201,11 +193,9 @@ def _create_cache_with_level(self, level: int, edge_whitelist=None): """Create a StrategyCache with the specified abstraction level.""" config = MagicMock() config.get_float = MagicMock( - side_effect=lambda k, d=0.0: 2.0 - if "THRESHOLD" in k - else 0.1 - if k == "MIN_EXECUTION_CONFIDENCE" - else d + side_effect=lambda k, d=0.0: ( + 2.0 if "THRESHOLD" in k else 0.1 if k == "MIN_EXECUTION_CONFIDENCE" else d + ) ) config.get_str = MagicMock(return_value="schema_v2_canonical") config.get_int = MagicMock( diff --git a/geaflow-ai/plugins/casts/tests/test_simple_path.py b/geaflow-ai/plugins/casts/tests/test_simple_path.py index 6ab624572..94216e59f 100644 --- a/geaflow-ai/plugins/casts/tests/test_simple_path.py +++ b/geaflow-ai/plugins/casts/tests/test_simple_path.py @@ -109,6 +109,7 @@ class TestSimplePathExecution: @pytest.fixture def mock_graph(self): """Create a simple mock graph for testing.""" + # Create a simple graph: A -> B -> C -> A (triangle) class MockGraph: def __init__(self): @@ -128,6 +129,7 @@ def __init__(self): @pytest.fixture def mock_schema(self): """Create a mock schema.""" + class MockSchema: def get_valid_outgoing_edge_labels(self, node_type): return ["friend"] @@ -139,7 +141,7 @@ def get_valid_incoming_edge_labels(self, node_type): async def test_simple_path_step_execution(self, mock_graph, mock_schema): """Test that simplePath() step passes through current node.""" - from simulation.executor import TraversalExecutor + from harness.simulation.executor import TraversalExecutor executor = TraversalExecutor(mock_graph, mock_schema) @@ -158,7 +160,7 @@ async def test_simple_path_step_execution(self, mock_graph, mock_schema): async def test_simple_path_filtering(self, mock_graph, mock_schema): """Test that simplePath filters out visited nodes.""" - from simulation.executor import TraversalExecutor + from harness.simulation.executor import TraversalExecutor executor = TraversalExecutor(mock_graph, 
mock_schema) @@ -194,7 +196,7 @@ async def test_simple_path_filtering(self, mock_graph, mock_schema): async def test_without_simple_path_allows_cycles(self, mock_graph, mock_schema): """Test that without simplePath(), cycles are allowed.""" - from simulation.executor import TraversalExecutor + from harness.simulation.executor import TraversalExecutor executor = TraversalExecutor(mock_graph, mock_schema) @@ -230,7 +232,7 @@ async def test_without_simple_path_allows_cycles(self, mock_graph, mock_schema): async def test_simple_path_allows_filter_steps(self, mock_graph, mock_schema): """Test that simplePath does not block non-traversal filter steps.""" - from simulation.executor import TraversalExecutor + from harness.simulation.executor import TraversalExecutor executor = TraversalExecutor(mock_graph, mock_schema) @@ -253,7 +255,7 @@ async def test_simple_path_allows_filter_steps(self, mock_graph, mock_schema): async def test_clear_path_history(self, mock_graph, mock_schema): """Test that clear_path_history properly cleans up.""" - from simulation.executor import TraversalExecutor + from harness.simulation.executor import TraversalExecutor executor = TraversalExecutor(mock_graph, mock_schema) diff --git a/geaflow-ai/plugins/casts/tests/test_starting_node_selection.py b/geaflow-ai/plugins/casts/tests/test_starting_node_selection.py index 1f234e79a..d63ad67c0 100644 --- a/geaflow-ai/plugins/casts/tests/test_starting_node_selection.py +++ b/geaflow-ai/plugins/casts/tests/test_starting_node_selection.py @@ -22,7 +22,7 @@ import pytest from core.config import DefaultConfiguration -from data.sources import SyntheticDataSource +from harness.data.sources import SyntheticDataSource from services.embedding import EmbeddingService from services.llm_oracle import LLMOracle @@ -40,9 +40,7 @@ def mock_config(): @pytest.mark.asyncio -async def test_recommend_starting_node_types_basic( - mock_embedding_service, mock_config -): +async def 
test_recommend_starting_node_types_basic(mock_embedding_service, mock_config): """Test basic happy-path for recommending starting node types.""" # Arrange oracle = LLMOracle(mock_embedding_service, mock_config) @@ -50,18 +48,16 @@ async def test_recommend_starting_node_types_basic( # Mock the LLM response mock_response = MagicMock() - mock_response.choices[0].message.content = '''```json + mock_response.choices[0].message.content = """```json ["Person", "Company"] - ```''' + ```""" oracle.client.chat.completions.create.return_value = mock_response goal = "Find risky investments between people and companies." available_types = {"Person", "Company", "Loan", "Account"} # Act - recommended = await oracle.recommend_starting_node_types( - goal, available_types - ) + recommended = await oracle.recommend_starting_node_types(goal, available_types) # Assert assert isinstance(recommended, list) @@ -71,50 +67,42 @@ async def test_recommend_starting_node_types_basic( @pytest.mark.asyncio -async def test_recommend_starting_node_types_malformed_json( - mock_embedding_service, mock_config -): +async def test_recommend_starting_node_types_malformed_json(mock_embedding_service, mock_config): """Test robustness against malformed JSON from LLM.""" # Arrange oracle = LLMOracle(mock_embedding_service, mock_config) oracle.client = AsyncMock() mock_response = MagicMock() - mock_response.choices[0].message.content = '''```json + mock_response.choices[0].message.content = """```json ["Person", "Company",,] - ```''' # Extra comma + ```""" # Extra comma oracle.client.chat.completions.create.return_value = mock_response # Act - recommended = await oracle.recommend_starting_node_types( - "test goal", {"Person", "Company"} - ) + recommended = await oracle.recommend_starting_node_types("test goal", {"Person", "Company"}) # Assert - assert recommended == [] # Should fail gracefully + assert recommended == [] # Should fail gracefully @pytest.mark.asyncio -async def 
test_recommend_starting_node_types_with_comments( - mock_embedding_service, mock_config -): +async def test_recommend_starting_node_types_with_comments(mock_embedding_service, mock_config): """Test that parse_jsons handles comments correctly.""" # Arrange oracle = LLMOracle(mock_embedding_service, mock_config) oracle.client = AsyncMock() mock_response = MagicMock() - mock_response.choices[0].message.content = '''```json + mock_response.choices[0].message.content = """```json // Top-level comment [ "Person", // Person node type "Company" // Company node type ] - ```''' + ```""" oracle.client.chat.completions.create.return_value = mock_response # Act - recommended = await oracle.recommend_starting_node_types( - "test goal", {"Person", "Company"} - ) + recommended = await oracle.recommend_starting_node_types("test goal", {"Person", "Company"}) # Assert assert set(recommended) == {"Person", "Company"} @@ -129,15 +117,13 @@ async def test_recommend_starting_node_types_filters_invalid_types( oracle = LLMOracle(mock_embedding_service, mock_config) oracle.client = AsyncMock() mock_response = MagicMock() - mock_response.choices[0].message.content = '''```json + mock_response.choices[0].message.content = """```json ["Person", "Unicorn"] -```''' +```""" oracle.client.chat.completions.create.return_value = mock_response # Act - recommended = await oracle.recommend_starting_node_types( - "test goal", {"Person", "Company"} - ) + recommended = await oracle.recommend_starting_node_types("test goal", {"Person", "Company"}) # Assert assert recommended == ["Person"] @@ -153,13 +139,16 @@ def synthetic_data_source(): "1": {"id": "1", "type": "Person"}, "2": {"id": "2", "type": "Company"}, "3": {"id": "3", "type": "Company"}, - "4": {"id": "4", "type": "Loan"}, # Degree 0 + "4": {"id": "4", "type": "Loan"}, # Degree 0 } source._edges = { - "0": [{"target": "1", "label": "friend"}, {"target": "2", "label": "invest"}], # Degree 2 - "1": [{"target": "3", "label": "invest"}], # Degree 1 - 
"2": [{"target": "0", "label": "customer"}, {"target": "3", "label": "partner"}], # Degree 2 - "3": [{"target": "1", "label": "customer"}], # Degree 1 + "0": [{"target": "1", "label": "friend"}, {"target": "2", "label": "invest"}], # Degree 2 + "1": [{"target": "3", "label": "invest"}], # Degree 1 + "2": [ + {"target": "0", "label": "customer"}, + {"target": "3", "label": "partner"}, + ], # Degree 2 + "3": [{"target": "1", "label": "customer"}], # Degree 1 } return source diff --git a/geaflow-ai/plugins/casts/tests/test_threshold_calculation.py b/geaflow-ai/plugins/casts/tests/test_threshold_calculation.py index a7d9ad901..f61bc6ec4 100644 --- a/geaflow-ai/plugins/casts/tests/test_threshold_calculation.py +++ b/geaflow-ai/plugins/casts/tests/test_threshold_calculation.py @@ -55,15 +55,17 @@ def test_formula_correctness_with_doc_examples(self): sku_head = self.create_mock_sku(eta=1000, sigma=1) threshold_head = calculate_dynamic_similarity_threshold(sku_head, kappa=0.01, beta=0.1) # Expected: approx 0.998 (allow small error) - self.assertAlmostEqual(threshold_head, 0.998, places=2, - msg="Head scenario threshold should be near 0.998") + self.assertAlmostEqual( + threshold_head, 0.998, places=2, msg="Head scenario threshold should be near 0.998" + ) # Example 2: Tail scenario (eta=0.5, sigma=1, beta=0.1, kappa=0.01) sku_tail = self.create_mock_sku(eta=0.5, sigma=1) threshold_tail = calculate_dynamic_similarity_threshold(sku_tail, kappa=0.01, beta=0.1) # Expected: approx 0.99 (more permissive) - self.assertAlmostEqual(threshold_tail, 0.99, places=2, - msg="Tail scenario threshold should be near 0.99") + self.assertAlmostEqual( + threshold_tail, 0.99, places=2, msg="Tail scenario threshold should be near 0.99" + ) # Example 3: Complex logic scenario (eta=1000, sigma=5, beta=0.1, kappa=0.01) sku_complex = self.create_mock_sku(eta=1000, sigma=5) @@ -71,13 +73,13 @@ def test_formula_correctness_with_doc_examples(self): sku_complex, kappa=0.01, beta=0.1 ) # Expected: near 
0.99, actual result is closer to 0.9988 - self.assertGreater(threshold_complex, 0.998, - msg="Complex-logic scenario threshold should be > 0.998") + self.assertGreater( + threshold_complex, 0.998, msg="Complex-logic scenario threshold should be > 0.998" + ) # Head scenario should be stricter than tail scenario self.assertGreater( - threshold_head, threshold_tail, - msg="High-frequency SKU should have a higher threshold" + threshold_head, threshold_tail, msg="High-frequency SKU should have a higher threshold" ) def test_monotonicity_with_confidence(self): @@ -99,7 +101,7 @@ def test_monotonicity_with_confidence(self): for i in range(1, len(thresholds)): msg = ( "Thresholds must be monotonic: " - f"eta={confidence_values[i]} should be >= eta={confidence_values[i-1]}" + f"eta={confidence_values[i]} should be >= eta={confidence_values[i - 1]}" ) self.assertGreaterEqual( thresholds[i], @@ -126,7 +128,7 @@ def test_monotonicity_with_complexity(self): for i in range(1, len(thresholds)): msg = ( "Threshold should increase with complexity: " - f"sigma={complexity_values[i]} should be >= sigma={complexity_values[i-1]}" + f"sigma={complexity_values[i]} should be >= sigma={complexity_values[i - 1]}" ) self.assertGreaterEqual( thresholds[i], @@ -172,12 +174,13 @@ def test_kappa_sensitivity(self): # As kappa increases, threshold should decrease. 
for i in range(1, len(thresholds)): self.assertLessEqual( - thresholds[i], thresholds[i-1], + thresholds[i], + thresholds[i - 1], msg=( "Threshold should decrease as kappa increases: " f"kappa={kappa_values[i]} -> {thresholds[i]:.4f} " - f"<= kappa={kappa_values[i-1]} -> {thresholds[i-1]:.4f}" - ) + f"<= kappa={kappa_values[i - 1]} -> {thresholds[i - 1]:.4f}" + ), ) def test_beta_sensitivity(self): @@ -199,9 +202,7 @@ def test_beta_sensitivity(self): threshold_high = calculate_dynamic_similarity_threshold( sku_high, kappa=kappa, beta=beta ) - threshold_low = calculate_dynamic_similarity_threshold( - sku_low, kappa=kappa, beta=beta - ) + threshold_low = calculate_dynamic_similarity_threshold(sku_low, kappa=kappa, beta=beta) gap = threshold_high - threshold_low threshold_gaps.append(gap) @@ -209,11 +210,12 @@ def test_beta_sensitivity(self): # As beta increases, the gap should increase. for i in range(1, len(threshold_gaps)): self.assertGreaterEqual( - threshold_gaps[i], threshold_gaps[i-1], + threshold_gaps[i], + threshold_gaps[i - 1], msg=( "Gap should increase as beta increases: " - f"beta={beta_values[i]} gap >= beta={beta_values[i-1]} gap" - ) + f"beta={beta_values[i]} gap >= beta={beta_values[i - 1]} gap" + ), ) def test_realistic_scenarios_with_current_config(self): @@ -234,17 +236,17 @@ def test_realistic_scenarios_with_current_config(self): for name, eta, sigma, (expected_min, expected_max) in test_cases: with self.subTest(scenario=name, eta=eta, sigma=sigma): sku = self.create_mock_sku(eta=eta, sigma=sigma) - threshold = calculate_dynamic_similarity_threshold( - sku, kappa=kappa, beta=beta - ) + threshold = calculate_dynamic_similarity_threshold(sku, kappa=kappa, beta=beta) self.assertGreaterEqual( - threshold, expected_min, - msg=f"{name}: threshold {threshold:.4f} should be >= {expected_min}" + threshold, + expected_min, + msg=f"{name}: threshold {threshold:.4f} should be >= {expected_min}", ) self.assertLessEqual( - threshold, expected_max, - 
msg=f"{name}: threshold {threshold:.4f} should be <= {expected_max}" + threshold, + expected_max, + msg=f"{name}: threshold {threshold:.4f} should be <= {expected_max}", ) def test_practical_matching_scenario(self): @@ -271,18 +273,22 @@ def test_practical_matching_scenario(self): # Old config should not match self.assertAlmostEqual( - threshold_old, 0.8915, delta=0.01, - msg=f"Old threshold should be near 0.8915, actual: {threshold_old:.4f}" + threshold_old, + 0.8915, + delta=0.01, + msg=f"Old threshold should be near 0.8915, actual: {threshold_old:.4f}", ) self.assertLess( - user_similarity, threshold_old, - msg=f"Old config should not match: {user_similarity:.4f} < {threshold_old:.4f}" + user_similarity, + threshold_old, + msg=f"Old config should not match: {user_similarity:.4f} < {threshold_old:.4f}", ) # Increasing kappa should lower the threshold self.assertLess( - threshold_new, threshold_old, - msg=f"Higher kappa should lower threshold: {threshold_new:.4f} < {threshold_old:.4f}" + threshold_new, + threshold_old, + msg=f"Higher kappa should lower threshold: {threshold_new:.4f} < {threshold_old:.4f}", ) print("\n[Scenario] SKU_17 (eta=20, sigma=2):") @@ -298,11 +304,12 @@ def test_practical_matching_scenario(self): ) self.assertLessEqual( - threshold_simple_old, user_similarity, + threshold_simple_old, + user_similarity, msg=( "Simple SKU should match under old config: " f"{threshold_simple_old:.4f} <= {user_similarity:.4f}" - ) + ), ) def test_mathematical_properties_summary(self): @@ -311,11 +318,7 @@ def test_mathematical_properties_summary(self): beta = 0.10 # Generate test points - test_points = [ - (eta, sigma) - for eta in [1, 2, 5, 10, 20, 50, 100] - for sigma in [1, 2, 3, 5] - ] + test_points = [(eta, sigma) for eta in [1, 2, 5, 10, 20, 50, 100] for sigma in [1, 2, 3, 5]] for eta, sigma in test_points: sku = self.create_mock_sku(eta=eta, sigma=sigma) @@ -334,21 +337,15 @@ def test_mathematical_properties_summary(self): threshold_high = 
calculate_dynamic_similarity_threshold( sku_high_freq, kappa=kappa, beta=beta ) - threshold_low = calculate_dynamic_similarity_threshold( - sku_low_freq, kappa=kappa, beta=beta - ) + threshold_low = calculate_dynamic_similarity_threshold(sku_low_freq, kappa=kappa, beta=beta) self.assertGreater( - threshold_high, threshold_low, - msg="High-frequency SKU should have a higher threshold" + threshold_high, threshold_low, msg="High-frequency SKU should have a higher threshold" ) # Ensure a meaningful gap gap_ratio = (threshold_high - threshold_low) / threshold_low - self.assertGreater( - gap_ratio, 0.01, - msg="Threshold gap should be > 1%" - ) + self.assertGreater(gap_ratio, 0.01, msg="Threshold gap should be > 1%") class TestThresholdIntegrationWithStrategyCache(unittest.TestCase): diff --git a/geaflow-ai/plugins/lightmem/.gitignore b/geaflow-ai/plugins/lightmem/.gitignore new file mode 100644 index 000000000..74cece03c --- /dev/null +++ b/geaflow-ai/plugins/lightmem/.gitignore @@ -0,0 +1,19 @@ +# Byte-compiled / optimized files +__pycache__/ +*.py[cod] +.pytest_cache/ +geaflow_ai_lightmem.egg-info/* + +# Environment variables +.env + +# Virtual environment +.venv/ +uv.lock + +# Logs +/logs/ + +# IDE / OS specific +.vscode/ +.DS_Store diff --git a/geaflow-ai/plugins/lightmem/api/__init__.py b/geaflow-ai/plugins/lightmem/api/__init__.py new file mode 100644 index 000000000..bc4625c29 --- /dev/null +++ b/geaflow-ai/plugins/lightmem/api/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""LightMem HTTP API (FastAPI).""" diff --git a/geaflow-ai/plugins/lightmem/api/app.py b/geaflow-ai/plugins/lightmem/api/app.py new file mode 100644 index 000000000..f6f618398 --- /dev/null +++ b/geaflow-ai/plugins/lightmem/api/app.py @@ -0,0 +1,137 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
from __future__ import annotations

from typing import Any
import uuid

from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from pydantic import BaseModel, ConfigDict, Field, ValidationError

from lightmem.memory_kernel import MemoryKernel
from api.envelope import Envelope

app = FastAPI(title="GeaFlow AI LightMem", version="0.1.0")

# Process-wide kernel: all requests share one in-memory ledger + views.
_KERNEL = MemoryKernel()


class _Message(BaseModel):
    """One chat message inside a write payload."""

    model_config = ConfigDict(extra="ignore")

    role: str
    content: str


class MemoryWritePayload(BaseModel):
    """Body of the `payload` field for POST /memory/write."""

    model_config = ConfigDict(extra="ignore")

    messages: list[_Message] = Field(default_factory=list)
    mode: str | None = None


class MemoryRecallPayload(BaseModel):
    """Body of the `payload` field for POST /memory/recall."""

    model_config = ConfigDict(extra="ignore")

    query: str
    limit: int | None = None


def _error_body(code: str, message: str) -> dict[str, Any]:
    """Build the protocol error envelope with a fresh trace id."""
    return {
        "ok": False,
        "error": {"code": code, "message": message},
        "trace": {"trace_id": f"tr_{uuid.uuid4().hex}"},
    }


@app.exception_handler(RequestValidationError)
async def _validation_exception_handler(_: Request, exc: RequestValidationError) -> JSONResponse:
    # Malformed envelope (scope/trace/body shape) -> protocol-level 422.
    return JSONResponse(status_code=422, content=_error_body("INVALID_REQUEST", str(exc)))


@app.exception_handler(ValidationError)
async def _payload_validation_exception_handler(_: Request, exc: ValidationError) -> JSONResponse:
    # Bug fix: a malformed `payload` dict is validated inside the endpoints via
    # `model_validate`, which raises pydantic.ValidationError rather than
    # FastAPI's RequestValidationError. Previously that fell through to the
    # catch-all handler and surfaced as 500 INTERNAL_ERROR; it is a client
    # error and must be reported as 422 INVALID_REQUEST.
    return JSONResponse(status_code=422, content=_error_body("INVALID_REQUEST", str(exc)))


@app.exception_handler(Exception)
async def _unhandled_exception_handler(_: Request, exc: Exception) -> JSONResponse:
    # Last resort: never leak internals beyond the exception type name.
    return JSONResponse(
        status_code=500,
        content=_error_body("INTERNAL_ERROR", f"internal error ({type(exc).__name__})"),
    )


@app.get("/health")
async def health() -> dict[str, Any]:
    """Liveness probe."""
    return {"status": "UP"}


@app.post("/memory/write")
async def memory_write(envelope: Envelope) -> JSONResponse:
    """Write chat messages into memory; returns committed actions + provenance."""
    trace = envelope.trace.with_defaults()
    env = envelope.model_copy(update={"trace": trace})

    payload = MemoryWritePayload.model_validate(env.payload)
    result = _KERNEL.write(
        messages=[m.model_dump() for m in payload.messages],
        scope=env.scope.model_dump(exclude_none=True),
        mode=payload.mode or "echo",
    )
    # Stamp the request trace id onto the provenance so writes are auditable.
    provenance = dict(result.get("provenance") or {})
    provenance["trace_id"] = trace.trace_id

    return JSONResponse(
        status_code=200,
        content={
            "ok": True,
            "api_version": env.api_version,
            "trace": trace.model_dump(),
            "payload": {"actions": result.get("actions") or [], "provenance": provenance},
        },
    )


@app.post("/memory/recall")
async def memory_recall(envelope: Envelope) -> JSONResponse:
    """Recall memory units for a query, restricted to the caller's scope."""
    trace = envelope.trace.with_defaults()
    env = envelope.model_copy(update={"trace": trace})

    payload = MemoryRecallPayload.model_validate(env.payload)
    result = _KERNEL.recall(
        query=payload.query,
        scope=env.scope.model_dump(exclude_none=True),
        limit=payload.limit or 5,
    )
    result["trace_id"] = trace.trace_id

    return JSONResponse(
        status_code=200,
        content={
            "ok": True,
            "api_version": env.api_version,
            "trace": trace.model_dump(),
            "payload": result,
        },
    )
See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import time +import uuid +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field, model_validator + + +class Scope(BaseModel): + """Hard boundary for all memory reads/writes.""" + + model_config = ConfigDict(extra="ignore") + + tenant_id: str | None = None + user_id: str | None = None + agent_id: str | None = None + run_id: str | None = None + actor_id: str | None = None + + @model_validator(mode="after") + def _scope_required(self) -> Scope: + has_any = any((self.tenant_id, self.user_id, self.agent_id, self.run_id, self.actor_id)) + if not has_any: + raise ValueError("scope is required (tenant_id/user_id/agent_id/run_id/actor_id)") + return self + + +class Trace(BaseModel): + model_config = ConfigDict(extra="ignore") + + trace_id: str | None = None + timestamp: float | None = None + caller: str | None = None + + def with_defaults(self) -> Trace: + return Trace( + trace_id=self.trace_id or f"tr_{uuid.uuid4().hex}", + timestamp=self.timestamp if self.timestamp is not None else time.time(), + caller=self.caller, + ) + + +class Envelope(BaseModel): + model_config = ConfigDict(extra="ignore") + + api_version: str = "v1" + scope: Scope + trace: Trace = Field(default_factory=Trace) + payload: dict[str, Any] = Field(default_factory=dict) diff --git a/geaflow-ai/plugins/lightmem/api/py.typed b/geaflow-ai/plugins/lightmem/api/py.typed new file mode 100644 index 000000000..df6030a22 --- /dev/null +++ b/geaflow-ai/plugins/lightmem/api/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561 to indicate this package is typed. 
+ diff --git a/geaflow-ai/plugins/lightmem/core/__init__.py b/geaflow-ai/plugins/lightmem/core/__init__.py new file mode 100644 index 000000000..8db9f7084 --- /dev/null +++ b/geaflow-ai/plugins/lightmem/core/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""LightMem: minimal memory kernel (ledger + views).""" diff --git a/geaflow-ai/plugins/lightmem/core/geaflow_ai_lightmem.egg-info/PKG-INFO b/geaflow-ai/plugins/lightmem/core/geaflow_ai_lightmem.egg-info/PKG-INFO new file mode 100644 index 000000000..a2e14fcd3 --- /dev/null +++ b/geaflow-ai/plugins/lightmem/core/geaflow_ai_lightmem.egg-info/PKG-INFO @@ -0,0 +1,8 @@ +Metadata-Version: 2.4 +Name: geaflow-ai-lightmem +Version: 0.1.0 +Summary: LightMem: ledger + views memory kernel +Requires-Python: >=3.10 +Requires-Dist: numpy>=2.0.0 +Provides-Extra: dev +Requires-Dist: pytest>=8.0.0; extra == "dev" diff --git a/geaflow-ai/plugins/lightmem/core/geaflow_ai_lightmem.egg-info/SOURCES.txt b/geaflow-ai/plugins/lightmem/core/geaflow_ai_lightmem.egg-info/SOURCES.txt new file mode 100644 index 000000000..0f955ab23 --- /dev/null +++ b/geaflow-ai/plugins/lightmem/core/geaflow_ai_lightmem.egg-info/SOURCES.txt @@ -0,0 +1,12 @@ +pyproject.toml +core/geaflow_ai_lightmem.egg-info/PKG-INFO +core/geaflow_ai_lightmem.egg-info/SOURCES.txt +core/geaflow_ai_lightmem.egg-info/dependency_links.txt +core/geaflow_ai_lightmem.egg-info/requires.txt +core/geaflow_ai_lightmem.egg-info/top_level.txt +core/lightmem/__init__.py +core/lightmem/ledger.py +core/lightmem/memory_kernel.py +core/lightmem/types.py +core/lightmem/views.py +tests/test_write_recall_provenance.py \ No newline at end of file diff --git a/geaflow-ai/plugins/lightmem/core/geaflow_ai_lightmem.egg-info/dependency_links.txt b/geaflow-ai/plugins/lightmem/core/geaflow_ai_lightmem.egg-info/dependency_links.txt new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/geaflow-ai/plugins/lightmem/core/geaflow_ai_lightmem.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/geaflow-ai/plugins/lightmem/core/geaflow_ai_lightmem.egg-info/requires.txt b/geaflow-ai/plugins/lightmem/core/geaflow_ai_lightmem.egg-info/requires.txt new file mode 100644 index 000000000..e2343a768 --- /dev/null +++ 
from __future__ import annotations

from dataclasses import replace
from typing import Iterable

from lightmem.types import Action, ActionType, LedgerEvent, MemoryUnit


class Ledger:
    """Append-only event log. Authoritative source of truth.

    Events are only ever appended; `_units` is the materialized "current
    state" derived from the event stream.
    """

    def __init__(self) -> None:
        self._events: list[LedgerEvent] = []
        # Current unit state keyed by unit id (derived from the events).
        self._units: dict[str, MemoryUnit] = {}

    def append(
        self,
        *,
        scope: dict[str, str],
        action: Action,
        candidate_set_snapshot: list[str] | None = None,
    ) -> LedgerEvent:
        """Record *action* under *scope*, apply it to the unit state, return the event.

        Raises:
            ValueError: if the action is missing the fields its type requires.
            KeyError: if UPDATE references an unknown unit id.
        """
        event = LedgerEvent(
            event_id=LedgerEvent.new_id(),
            timestamp=LedgerEvent.now_ts(),
            scope=dict(scope),
            action=action,
            candidate_set_snapshot=list(candidate_set_snapshot) if candidate_set_snapshot else None,
        )
        self._events.append(event)

        # Idiom fix: compare against the ActionType enum members directly
        # instead of reaching them through the instance attribute
        # (`action.event_type.ADD`), which obscured that this is an enum match.
        # `==` (not `is`) preserves str-enum equality with plain strings.
        if action.event_type == ActionType.ADD:
            if not action.unit_id or not action.content:
                raise ValueError("ADD requires unit_id and content")
            unit = MemoryUnit(
                id=action.unit_id,
                content=action.content,
                metadata={"scope": dict(scope)},
                last_event_id=event.event_id,
            )
            self._units[unit.id] = unit
        elif action.event_type == ActionType.UPDATE:
            if not action.unit_id or action.content is None:
                raise ValueError("UPDATE requires unit_id and content")
            existing = self._units.get(action.unit_id)
            if existing is None:
                raise KeyError(f"unknown unit_id={action.unit_id}")
            updated = replace(existing, content=action.content, last_event_id=event.event_id)
            self._units[updated.id] = updated
        elif action.event_type == ActionType.DELETE:
            if not action.unit_id:
                raise ValueError("DELETE requires unit_id")
            # Deleting an unknown unit is a deliberate no-op.
            self._units.pop(action.unit_id, None)
        elif action.event_type == ActionType.NONE:
            pass

        return event

    def get_unit(self, unit_id: str) -> MemoryUnit | None:
        """Return the current state of a unit, or None if absent/deleted."""
        return self._units.get(unit_id)

    def list_units(self) -> list[MemoryUnit]:
        """Snapshot list of all live units."""
        return list(self._units.values())

    def iter_events(self) -> Iterable[LedgerEvent]:
        """Iterate the full event history in append order."""
        return iter(self._events)
from __future__ import annotations

import uuid
from dataclasses import asdict
from typing import Any

from lightmem.ledger import Ledger
from lightmem.types import Action, ActionType, MemoryUnit, RecallResult
from lightmem.views import ViewManager


def _normalize_scope(scope: dict[str, Any]) -> dict[str, str]:
    """Stringify scope values and drop None / whitespace-only entries."""
    normalized: dict[str, str] = {}
    for key, value in scope.items():
        if value is None:
            continue
        text = str(value)
        if text.strip():
            normalized[key] = text
    return normalized


class MemoryKernel:
    """Facade tying the append-only ledger to the derived recall views."""

    def __init__(self) -> None:
        self.ledger = Ledger()
        self.views = ViewManager()

    def write(
        self,
        *,
        messages: list[dict[str, Any]],
        scope: dict[str, Any],
        mode: str = "echo",
    ) -> dict[str, Any]:
        """Turn chat *messages* into memory actions and commit them.

        Returns a dict with the committed actions and write provenance
        (ledger event ids + unitize count).
        """
        scope_norm = _normalize_scope(scope)
        if not scope_norm:
            raise ValueError("scope is required")

        # Unknown modes silently degrade to the deterministic echo mode,
        # so only an explicit "llm" takes the (placeholder) LLM path.
        if mode == "llm":
            # Placeholder for protocol-compliant LLM pipeline (unitize/ground/decide/commit).
            actions: list[Action] = []
        else:
            # Echo mode: one ADD per non-empty user message.
            actions = [
                Action(
                    event_type=ActionType.ADD,
                    unit_id=f"mem_{uuid.uuid4().hex}",
                    content=text,
                )
                for text in (
                    str(msg.get("content") or "").strip()
                    for msg in messages
                    if msg.get("role") == "user"
                )
                if text
            ]

        event_ids: list[str] = []
        for action in actions:
            # Commit to the ledger first, then mirror into the vector view.
            event = self.ledger.append(scope=scope_norm, action=action, candidate_set_snapshot=None)
            event_ids.append(event.event_id)
            unit = self.ledger.get_unit(action.unit_id or "")
            if unit is not None:
                self.views.vector_view.upsert(unit)

        return {
            "actions": [asdict(action) for action in actions],
            "provenance": {"event_ids": event_ids, "unitize_count": len(actions)},
        }

    def recall(
        self,
        *,
        query: str,
        scope: dict[str, Any],
        limit: int = 5,
    ) -> dict[str, Any]:
        """Recall up to *limit* units matching *query* within *scope*."""
        scope_norm = _normalize_scope(scope)
        if not scope_norm:
            raise ValueError("scope is required")

        result: RecallResult = self.views.vector_view.recall(
            query=query, scope_filter=scope_norm, limit=limit
        )

        units = [
            {
                "id": hit.unit.id,
                "content": hit.unit.content,
                "metadata": hit.unit.metadata,
                "last_event_id": hit.unit.last_event_id,
            }
            for hit in result.hits
        ]
        provenance = [
            {
                "memory_id": hit.unit.id,
                "view_source": hit.view_source,
                "last_event_id": hit.unit.last_event_id,
                "similarity": hit.similarity,
            }
            for hit in result.hits
        ]
        return {"units": units, "provenance": provenance}
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import time +import uuid +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +import numpy as np + + +class ActionType(str, Enum): + ADD = "ADD" + UPDATE = "UPDATE" + DELETE = "DELETE" + NONE = "NONE" + + +@dataclass(frozen=True) +class Action: + event_type: ActionType + unit_id: str | None = None + content: str | None = None + previous_content: str | None = None + + +@dataclass +class MemoryUnit: + id: str + content: str + metadata: dict[str, Any] = field(default_factory=dict) + embedding: np.ndarray | None = None + last_event_id: str | None = None + + +@dataclass(frozen=True) +class LedgerEvent: + event_id: str + timestamp: float + scope: dict[str, str] + action: Action + candidate_set_snapshot: list[str] | None = None + + @staticmethod + def new_id() -> str: + return f"evt_{uuid.uuid4().hex}" + + @staticmethod + def now_ts() -> float: + return time.time() + + +@dataclass(frozen=True) +class RecallHit: + unit: MemoryUnit + similarity: float + view_source: str + + +@dataclass(frozen=True) +class RecallResult: + hits: list[RecallHit] diff --git a/geaflow-ai/plugins/lightmem/core/views.py b/geaflow-ai/plugins/lightmem/core/views.py new file mode 100644 index 000000000..e2e51827b --- /dev/null +++ b/geaflow-ai/plugins/lightmem/core/views.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one 
+# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import hashlib +from dataclasses import replace +from typing import Any + +import numpy as np + +from lightmem.types import ActionType, MemoryUnit, RecallHit, RecallResult + + +class _LocalEmbedder: + def __init__(self, dimension: int = 64) -> None: + self._dimension = max(8, int(dimension)) + + def embed(self, text: str) -> np.ndarray: + digest = hashlib.sha256(text.encode("utf-8")).digest() + raw = np.frombuffer(digest, dtype=np.uint8).astype(np.float32) + vec = np.resize(raw, self._dimension) + norm = np.linalg.norm(vec) + return vec if norm == 0 else vec / norm + + +def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: + denom = np.linalg.norm(a) * np.linalg.norm(b) + if denom == 0: + return 0.0 + return float(np.dot(a, b) / denom) + + +def _scope_matches(scope_filter: dict[str, str], unit_scope: dict[str, str]) -> bool: + for k, v in scope_filter.items(): + if unit_scope.get(k) != v: + return False + return True + + +class VectorView: + """In-memory vector index with scope filtering.""" + + def __init__(self, embedder: _LocalEmbedder | None = None) -> None: + self._embedder = embedder or _LocalEmbedder() + self._units: dict[str, MemoryUnit] = {} + self._vectors: dict[str, np.ndarray] = {} + + 
def upsert(self, unit: MemoryUnit) -> None: + vector = ( + unit.embedding if unit.embedding is not None else self._embedder.embed(unit.content) + ) + self._units[unit.id] = replace(unit, embedding=vector) + self._vectors[unit.id] = vector + + def delete(self, unit_id: str) -> None: + self._units.pop(unit_id, None) + self._vectors.pop(unit_id, None) + + def apply_action(self, action_type: ActionType, unit: MemoryUnit | None) -> None: + if action_type in (ActionType.ADD, ActionType.UPDATE): + if unit is None: + return + self.upsert(unit) + elif action_type == ActionType.DELETE: + if unit is None: + return + self.delete(unit.id) + + def recall(self, *, query: str, scope_filter: dict[str, str], limit: int = 5) -> RecallResult: + q = self._embedder.embed(query) + hits: list[RecallHit] = [] + for unit_id, unit in self._units.items(): + unit_scope: dict[str, Any] = unit.metadata.get("scope") or {} + unit_scope = {k: str(v) for k, v in unit_scope.items()} + if not _scope_matches(scope_filter, unit_scope): + continue + vec = self._vectors.get(unit_id) + if vec is None: + continue + hits.append( + RecallHit( + unit=unit, + similarity=_cosine_similarity(q, vec), + view_source="vector", + ) + ) + hits.sort(key=lambda h: h.similarity, reverse=True) + return RecallResult(hits=hits[: max(0, int(limit))]) + + +class ViewManager: + def __init__(self) -> None: + self.vector_view = VectorView() diff --git a/geaflow-ai/plugins/lightmem/pyproject.toml b/geaflow-ai/plugins/lightmem/pyproject.toml new file mode 100644 index 000000000..23c939a38 --- /dev/null +++ b/geaflow-ai/plugins/lightmem/pyproject.toml @@ -0,0 +1,50 @@ +[project] +name = "geaflow-ai-lightmem" +version = "0.1.0" +description = "LightMem: ledger + views memory kernel" +requires-python = ">=3.10" +dependencies = [ + "numpy>=2.0.0", + "fastapi>=0.115.0", + "uvicorn>=0.30.0", + "pydantic>=2.6.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0.0", + "httpx>=0.27.0", + "ruff>=0.14.0", + "mypy>=1.10.0", +] + 
#!/usr/bin/env bash
set -euo pipefail

# One-click smoke test for LightMem service.
#
# What it does:
# 1) Ensures `.venv` exists and runs `uv sync` into it (no activation required).
# 2) Starts `uvicorn` on HOST:PORT (unless a healthy service is already running there).
# 3) Runs HTTP checks for `/health`, `POST /memory/write`, and `POST /memory/recall`.
# 4) Shuts down the service if this script started it.

# Connection settings, overridable via environment variables.
HOST="${GEAFLOW_AI_LIGHTMEM_HOST:-127.0.0.1}"
PORT="${GEAFLOW_AI_LIGHTMEM_PORT:-5002}"
TIMEOUT_SECONDS="${GEAFLOW_AI_LIGHTMEM_TIMEOUT_SECONDS:-20}"

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PLUGIN_DIR="$(cd "${SCRIPT_DIR}/.." && pwd)"
LOG_DIR="${PLUGIN_DIR}/logs"
LOG_FILE="${LOG_DIR}/smoke_lightmem.log"

mkdir -p "${LOG_DIR}"
cd "${PLUGIN_DIR}"

if [[ ! -d .venv ]]; then
    python3.11 -m venv .venv
fi

VENV_PY="${PLUGIN_DIR}/.venv/bin/python"
if [[ ! -x "${VENV_PY}" ]]; then
    echo "[smoke] ERROR: venv python not found at ${VENV_PY}"
    exit 1
fi

# Kill the server we started (if any) and wait up to ~5s for it to exit.
cleanup() {
    if [[ "${STARTED_HERE:-false}" == "true" && -n "${SERVER_PID:-}" ]]; then
        kill "${SERVER_PID}" >/dev/null 2>&1 || true
        for _ in {1..50}; do
            if kill -0 "${SERVER_PID}" >/dev/null 2>&1; then
                sleep 0.1
            else
                break
            fi
        done
    fi
}
trap cleanup EXIT

echo "[smoke] Sync deps (lightmem venv) ..."
uv sync --extra dev

BASE_URL="http://${HOST}:${PORT}"
STARTED_HERE="false"
SERVER_PID=""

# Returns 0 iff GET /health answers with {"status": "UP"}.
health_ok() {
    local body
    body="$(curl -fsS "${BASE_URL}/health" 2>/dev/null || true)"
    if [[ -z "${body}" ]]; then
        return 1
    fi
    "${VENV_PY}" -c 'import json,sys; obj=json.loads(sys.argv[1]); assert obj.get("status") == "UP"' \
        "${body}" >/dev/null 2>&1
}

if health_ok; then
    echo "[smoke] LightMem already running at ${BASE_URL}"
else
    echo "[smoke] Starting LightMem at ${BASE_URL} ..."
    "${VENV_PY}" -m uvicorn api.app:app --host "${HOST}" --port "${PORT}" --log-level info >"${LOG_FILE}" 2>&1 &
    SERVER_PID="$!"
    disown "${SERVER_PID}" >/dev/null 2>&1 || true
    STARTED_HERE="true"

    # Poll /health until it is up or the deadline passes.
    deadline=$(( $(date +%s) + TIMEOUT_SECONDS ))
    while true; do
        if health_ok; then
            break
        fi
        if [[ $(date +%s) -ge ${deadline} ]]; then
            echo "[smoke] ERROR: LightMem did not become healthy within ${TIMEOUT_SECONDS}s"
            echo "[smoke] Log: ${LOG_FILE}"
            tail -n 200 "${LOG_FILE}" || true
            exit 1
        fi
        sleep 0.2
    done
fi

echo "[smoke] Health OK"

# POST a JSON body, assert the HTTP status, and echo the response body.
post_json() {
    local path="$1"
    local body="$2"
    local expected_code="$3"

    local tmp
    tmp="$(mktemp)"
    local code
    code="$(curl -sS -o "${tmp}" -w "%{http_code}" -X POST "${BASE_URL}${path}" \
        -H "Content-Type: application/json" \
        -d "${body}" || true)"
    local resp
    resp="$(cat "${tmp}")"
    rm -f "${tmp}"

    if [[ "${code}" != "${expected_code}" ]]; then
        echo "[smoke] ERROR: POST ${path} expected HTTP ${expected_code}, got ${code}"
        echo "[smoke] Response:"
        echo "${resp}"
        exit 1
    fi

    printf "%s" "${resp}"
}

echo "[smoke] Test: empty scope rejected (expect 422)"
resp="$(post_json "/memory/recall" '{
  "api_version":"v1",
  "scope":{},
  "trace":{},
  "payload":{"query":"hi","limit":5}
}' "422")"
printf "%s" "${resp}" | "${VENV_PY}" -c 'import json,sys; obj=json.load(sys.stdin); assert obj.get("ok") is False'
echo "[smoke] OK"

echo "[smoke] Test: memory.write (echo)"
resp="$(post_json "/memory/write" '{
  "api_version":"v1",
  "scope":{"user_id":"u_smoke"},
  "trace":{},
  "payload":{
    "mode":"echo",
    "messages":[{"role":"user","content":"I like coffee."}]
  }
}' "200")"
printf "%s" "${resp}" | "${VENV_PY}" -c \
    'import json,sys; obj=json.load(sys.stdin); assert obj.get("ok") is True; p=obj.get("payload") or {}; actions=p.get("actions") or []; assert len(actions) >= 1'
echo "[smoke] OK"

echo "[smoke] Test: memory.recall"
resp="$(post_json "/memory/recall" '{
  "api_version":"v1",
  "scope":{"user_id":"u_smoke"},
  "trace":{},
  "payload":{"query":"coffee","limit":5}
}' "200")"
printf "%s" "${resp}" | "${VENV_PY}" -c \
    'import json,sys; obj=json.load(sys.stdin); assert obj.get("ok") is True; p=obj.get("payload") or {}; units=p.get("units") or []; assert len(units) >= 1'
echo "[smoke] OK"

echo
echo "[smoke] SMOKE OK: lightmem service"
echo "[smoke] Base URL: ${BASE_URL}"
if [[ "${STARTED_HERE}" == "true" ]]; then
    echo "[smoke] Log: ${LOG_FILE}"
fi
+ +from __future__ import annotations + +import uuid + +from fastapi.testclient import TestClient + +from api.app import app + + +def test_health_ok() -> None: + client = TestClient(app) + res = client.get("/health") + assert res.status_code == 200 + assert res.json() == {"status": "UP"} + + +def test_empty_scope_rejected() -> None: + client = TestClient(app) + res = client.post( + "/memory/recall", + json={ + "api_version": "v1", + "scope": {}, + "trace": {}, + "payload": {"query": "hi", "limit": 5}, + }, + ) + assert res.status_code == 422 + body = res.json() + assert body["ok"] is False + + +def test_write_then_recall() -> None: + user_id = f"u_{uuid.uuid4().hex}" + client = TestClient(app) + + write_res = client.post( + "/memory/write", + json={ + "api_version": "v1", + "scope": {"user_id": user_id}, + "trace": {}, + "payload": { + "mode": "echo", + "messages": [{"role": "user", "content": "I like coffee."}], + }, + }, + ) + assert write_res.status_code == 200 + write_body = write_res.json() + assert write_body["ok"] is True + assert write_body["payload"]["actions"] + + recall_res = client.post( + "/memory/recall", + json={ + "api_version": "v1", + "scope": {"user_id": user_id}, + "trace": {}, + "payload": {"query": "coffee", "limit": 5}, + }, + ) + assert recall_res.status_code == 200 + recall_body = recall_res.json() + assert recall_body["ok"] is True + assert recall_body["payload"]["units"] diff --git a/geaflow-ai/plugins/lightmem/tests/test_write_recall_provenance.py b/geaflow-ai/plugins/lightmem/tests/test_write_recall_provenance.py new file mode 100644 index 000000000..7877d8ae8 --- /dev/null +++ b/geaflow-ai/plugins/lightmem/tests/test_write_recall_provenance.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from lightmem.memory_kernel import MemoryKernel + + +def test_write_then_recall_has_provenance() -> None: + kernel = MemoryKernel() + write_res = kernel.write( + messages=[{"role": "user", "content": "I prefer dark mode."}], + scope={"user_id": "u1"}, + mode="echo", + ) + assert "actions" in write_res + assert write_res["provenance"]["event_ids"] + + recall_res = kernel.recall(query="preferences", scope={"user_id": "u1"}, limit=5) + assert recall_res["units"] + assert recall_res["provenance"] + assert recall_res["provenance"][0]["last_event_id"] is not None diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/GeaFlowMemoryServer.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/GeaFlowMemoryServer.java index d39123183..c83514ab0 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/GeaFlowMemoryServer.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/GeaFlowMemoryServer.java @@ -19,6 +19,7 @@ package org.apache.geaflow.ai; +import com.google.gson.Gson; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -27,7 +28,16 @@ import org.apache.geaflow.ai.graph.*; import org.apache.geaflow.ai.graph.io.*; import org.apache.geaflow.ai.index.EntityAttributeIndexStore; +import org.apache.geaflow.ai.index.vector.CastsVector; import org.apache.geaflow.ai.index.vector.KeywordVector; +import 
org.apache.geaflow.ai.protocol.AiScope; +import org.apache.geaflow.ai.memory.MemoryClient; +import org.apache.geaflow.ai.memory.MemoryRecallRequest; +import org.apache.geaflow.ai.memory.MemoryRecallResponse; +import org.apache.geaflow.ai.memory.MemoryWriteRequest; +import org.apache.geaflow.ai.memory.MemoryWriteResponse; +import org.apache.geaflow.ai.memory.LightMemConfig; +import org.apache.geaflow.ai.memory.LightMemRestClient; import org.apache.geaflow.ai.search.VectorSearch; import org.apache.geaflow.ai.service.ServerMemoryCache; import org.apache.geaflow.ai.verbalization.Context; @@ -45,6 +55,8 @@ public class GeaFlowMemoryServer { private static final String SERVER_NAME = "geaflow-memory-server"; private static final int DEFAULT_PORT = 8080; + private static final Gson GSON = new Gson(); + private static final ServerMemoryCache CACHE = new ServerMemoryCache(); public static void main(String[] args) { @@ -218,6 +230,75 @@ public String execQuery(@Param("sessionId") String sessionId, return context.toString(); } + @Post + @Mapping("/query/castsExec") + public String execCastsQuery(@Param("sessionId") String sessionId, + @Param("maxDepth") Integer maxDepth, + @Body String query) { + String graphName = CACHE.getGraphNameBySession(sessionId); + if (graphName == null) { + throw new RuntimeException("Graph not exist."); + } + GraphMemoryServer server = CACHE.getServerByName(graphName); + int depth = maxDepth == null ? 5 : maxDepth; + VectorSearch search = new VectorSearch(null, sessionId); + // Seed the starting subgraph by keyword recall, then let CASTS traverse multi-hop. 
+ search.addVector(new KeywordVector(query)); + search.addVector(new CastsVector(query, depth)); + server.search(search); + Context context = server.verbalize(sessionId, + new SubgraphSemanticPromptFunction(server.getGraphAccessors().get(0))); + return context.toString(); + } + + @Post + @Mapping("/memory/write") + public String memoryWrite( + @Param("tenantId") String tenantId, + @Param("userId") String userId, + @Param("agentId") String agentId, + @Param("runId") String runId, + @Param("actorId") String actorId, + @Param("sessionId") String sessionId, + @Body String input + ) { + String effectiveRunId = isBlank(runId) ? normalized(sessionId) : normalized(runId); + AiScope scope = buildScope(tenantId, userId, agentId, effectiveRunId, actorId); + validateScope(scope); + + MemoryWriteRequest request = isBlank(input) ? new MemoryWriteRequest() : GSON.fromJson(input, MemoryWriteRequest.class); + if (request == null) { + request = new MemoryWriteRequest(); + } + MemoryClient client = new MemoryClient(new LightMemRestClient(LightMemConfig.fromEnv()), SERVER_NAME); + MemoryWriteResponse resp = client.write(scope, request); + return GSON.toJson(resp); + } + + @Post + @Mapping("/memory/recall") + public String memoryRecall( + @Param("tenantId") String tenantId, + @Param("userId") String userId, + @Param("agentId") String agentId, + @Param("runId") String runId, + @Param("actorId") String actorId, + @Param("sessionId") String sessionId, + @Body String input + ) { + String effectiveRunId = isBlank(runId) ? normalized(sessionId) : normalized(runId); + AiScope scope = buildScope(tenantId, userId, agentId, effectiveRunId, actorId); + validateScope(scope); + + MemoryRecallRequest request = isBlank(input) ? 
null : GSON.fromJson(input, MemoryRecallRequest.class); + if (request == null || isBlank(request.query)) { + throw new RuntimeException("MemoryRecallRequest.query is required"); + } + MemoryClient client = new MemoryClient(new LightMemRestClient(LightMemConfig.fromEnv()), SERVER_NAME); + MemoryRecallResponse resp = client.recall(scope, request); + return GSON.toJson(resp); + } + @Post @Mapping("/query/result") public String getResult(@Param("sessionId") String sessionId) { @@ -229,4 +310,43 @@ public String getResult(@Param("sessionId") String sessionId) { List result = server.getSessionEntities(sessionId); return result.toString(); } + + private static AiScope buildScope( + String tenantId, + String userId, + String agentId, + String runId, + String actorId + ) { + AiScope scope = new AiScope(); + scope.tenantId = normalized(tenantId); + scope.userId = normalized(userId); + scope.agentId = normalized(agentId); + scope.runId = normalized(runId); + scope.actorId = normalized(actorId); + return scope; + } + + private static void validateScope(AiScope scope) { + boolean ok = !isBlank(scope.tenantId) + || !isBlank(scope.userId) + || !isBlank(scope.agentId) + || !isBlank(scope.runId) + || !isBlank(scope.actorId); + if (!ok) { + throw new RuntimeException("Scope is required (tenantId/userId/agentId/runId/actorId)"); + } + } + + private static String normalized(String value) { + if (value == null) { + return null; + } + String trimmed = value.trim(); + return trimmed.isEmpty() ? 
null : trimmed; + } + + private static boolean isBlank(String value) { + return value == null || value.trim().isEmpty(); + } } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/GraphMemoryServer.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/GraphMemoryServer.java index b571fe540..30c219742 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/GraphMemoryServer.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/GraphMemoryServer.java @@ -26,9 +26,13 @@ import java.util.stream.Collectors; import org.apache.geaflow.ai.graph.GraphAccessor; import org.apache.geaflow.ai.graph.GraphEntity; +import org.apache.geaflow.ai.casts.CastsConfig; +import org.apache.geaflow.ai.casts.CastsOperator; +import org.apache.geaflow.ai.casts.CastsRestClient; import org.apache.geaflow.ai.index.EmbeddingIndexStore; import org.apache.geaflow.ai.index.EntityAttributeIndexStore; import org.apache.geaflow.ai.index.IndexStore; +import org.apache.geaflow.ai.index.vector.VectorType; import org.apache.geaflow.ai.operator.EmbeddingOperator; import org.apache.geaflow.ai.operator.SearchOperator; import org.apache.geaflow.ai.operator.SessionOperator; @@ -81,16 +85,29 @@ public String search(VectorSearch search) { sessionManagement.createSession(sessionId); } - for (IndexStore indexStore : indexStores) { - if (indexStore instanceof EntityAttributeIndexStore) { - SessionOperator searchOperator = new SessionOperator(graphAccessors.get(0), indexStore); - applySearch(sessionId, searchOperator, search); + if (search.getVectorMap().containsKey(VectorType.KeywordVector)) { + for (IndexStore indexStore : indexStores) { + if (indexStore instanceof EntityAttributeIndexStore) { + SessionOperator searchOperator = new SessionOperator(graphAccessors.get(0), indexStore); + applySearch(sessionId, searchOperator, search); + } } - if (indexStore instanceof EmbeddingIndexStore) { - EmbeddingOperator embeddingOperator = new EmbeddingOperator(graphAccessors.get(0), indexStore); - 
applySearch(sessionId, embeddingOperator, search); + } + + if (search.getVectorMap().containsKey(VectorType.EmbeddingVector)) { + for (IndexStore indexStore : indexStores) { + if (indexStore instanceof EmbeddingIndexStore) { + EmbeddingOperator embeddingOperator = new EmbeddingOperator(graphAccessors.get(0), indexStore); + applySearch(sessionId, embeddingOperator, search); + } } } + + if (search.getVectorMap().containsKey(VectorType.CastsVector)) { + CastsRestClient castsClient = new CastsRestClient(CastsConfig.fromEnv()); + CastsOperator castsOperator = new CastsOperator(graphAccessors.get(0), castsClient); + applySearch(sessionId, castsOperator, search); + } return sessionId; } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsConfig.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsConfig.java new file mode 100644 index 000000000..5c41aa4e5 --- /dev/null +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsConfig.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.geaflow.ai.casts; + +public class CastsConfig { + + public final String baseUrl; + public final String token; + + public CastsConfig(String baseUrl, String token) { + this.baseUrl = baseUrl; + this.token = token == null ? "" : token; + } + + public static CastsConfig fromEnv() { + String url = System.getenv("GEAFLOW_AI_CASTS_URL"); + if (url == null || url.isEmpty()) { + url = "http://localhost:5001"; + } + String token = System.getenv("GEAFLOW_AI_CASTS_TOKEN"); + return new CastsConfig(url, token); + } + + public String decisionUrl() { + return join(baseUrl, "/casts/decision"); + } + + private static String join(String base, String path) { + if (base.endsWith("/") && path.startsWith("/")) { + return base.substring(0, base.length() - 1) + path; + } + if (!base.endsWith("/") && !path.startsWith("/")) { + return base + "/" + path; + } + return base + path; + } +} + diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsDecisionParser.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsDecisionParser.java new file mode 100644 index 000000000..979ec87f0 --- /dev/null +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsDecisionParser.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.ai.casts; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public class CastsDecisionParser { + + public enum Kind { + OUT, + IN, + BOTH, + OUT_E, + IN_E, + BOTH_E, + IN_V, + OUT_V, + OTHER_V, + DEDUP, + SIMPLE_PATH, + LIMIT, + ORDER_BY, + HAS, + STOP, + UNKNOWN + } + + public static class ParsedDecision { + public final Kind kind; + public final String label; + public final String raw; + + public ParsedDecision(Kind kind, String label, String raw) { + this.kind = kind; + this.label = label; + this.raw = raw; + } + } + + private static final Pattern STEP_WITH_LABEL = + Pattern.compile("^(out|in|both|outE|inE|bothE)\\('([^']+)'\\)$"); + + public static ParsedDecision parse(String decision) { + if (decision == null) { + return new ParsedDecision(Kind.UNKNOWN, null, null); + } + String d = decision.trim(); + if (d.isEmpty() || "stop".equals(d)) { + return new ParsedDecision(Kind.STOP, null, d); + } + + Matcher m = STEP_WITH_LABEL.matcher(d); + if (m.matches()) { + String op = m.group(1); + String label = m.group(2); + if ("out".equals(op)) { + return new ParsedDecision(Kind.OUT, label, d); + } else if ("in".equals(op)) { + return new ParsedDecision(Kind.IN, label, d); + } else if ("both".equals(op)) { + return new ParsedDecision(Kind.BOTH, label, d); + } else if ("outE".equals(op)) { + return new ParsedDecision(Kind.OUT_E, label, d); + } else if ("inE".equals(op)) { + return new ParsedDecision(Kind.IN_E, label, d); + } else if ("bothE".equals(op)) { + return new ParsedDecision(Kind.BOTH_E, label, d); + } + } + + if ("inV()".equals(d)) { + return new ParsedDecision(Kind.IN_V, null, d); + } + if ("outV()".equals(d)) { + return new ParsedDecision(Kind.OUT_V, null, d); + } + if ("otherV()".equals(d)) { + return new ParsedDecision(Kind.OTHER_V, null, d); + } + if ("dedup()".equals(d)) { + return new 
ParsedDecision(Kind.DEDUP, null, d); + } + if ("simplePath()".equals(d)) { + return new ParsedDecision(Kind.SIMPLE_PATH, null, d); + } + if (d.startsWith("limit(") && d.endsWith(")")) { + return new ParsedDecision(Kind.LIMIT, null, d); + } + if (d.startsWith("order().by(") && d.endsWith(")")) { + return new ParsedDecision(Kind.ORDER_BY, null, d); + } + if (d.startsWith("has(") && d.endsWith(")")) { + return new ParsedDecision(Kind.HAS, null, d); + } + + return new ParsedDecision(Kind.UNKNOWN, null, d); + } +} + diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsDecisionRequest.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsDecisionRequest.java new file mode 100644 index 000000000..834a06925 --- /dev/null +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsDecisionRequest.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.geaflow.ai.casts; + +import com.google.gson.annotations.SerializedName; +import java.util.List; +import java.util.Map; + +public class CastsDecisionRequest { + + public String goal; + + @SerializedName("max_depth") + public Integer maxDepth; + + public Traversal traversal; + + public Node node; + + @SerializedName("graph_schema") + public GraphSchema graphSchema; + + public static class Traversal { + @SerializedName("structural_signature") + public String structuralSignature; + + @SerializedName("step_index") + public Integer stepIndex; + } + + public static class Node { + public String label; + public Map properties; + } + + public static class GraphSchema { + @SerializedName("schema_fingerprint") + public String schemaFingerprint; + + @SerializedName("valid_outgoing_labels") + public List validOutgoingLabels; + + @SerializedName("valid_incoming_labels") + public List validIncomingLabels; + + @SerializedName("schema_summary") + public String schemaSummary; + } +} + diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsDecisionResponse.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsDecisionResponse.java new file mode 100644 index 000000000..9edc7496f --- /dev/null +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsDecisionResponse.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.ai.casts; + +import com.google.gson.annotations.SerializedName; +import java.util.Map; + +public class CastsDecisionResponse { + + public String decision; + + @SerializedName("match_type") + public String matchType; + + @SerializedName("sku_id") + public String skuId; + + public Map provenance; +} + diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsOperator.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsOperator.java new file mode 100644 index 000000000..192891e72 --- /dev/null +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsOperator.java @@ -0,0 +1,449 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.geaflow.ai.casts; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.geaflow.ai.graph.GraphAccessor; +import org.apache.geaflow.ai.graph.GraphEdge; +import org.apache.geaflow.ai.graph.GraphEntity; +import org.apache.geaflow.ai.graph.GraphVertex; +import org.apache.geaflow.ai.graph.io.EdgeSchema; +import org.apache.geaflow.ai.graph.io.GraphSchema; +import org.apache.geaflow.ai.graph.io.Schema; +import org.apache.geaflow.ai.graph.io.VertexSchema; +import org.apache.geaflow.ai.index.vector.CastsVector; +import org.apache.geaflow.ai.index.vector.IVector; +import org.apache.geaflow.ai.index.vector.VectorType; +import org.apache.geaflow.ai.operator.SearchOperator; +import org.apache.geaflow.ai.protocol.AiScope; +import org.apache.geaflow.ai.protocol.AiTrace; +import org.apache.geaflow.ai.search.VectorSearch; +import org.apache.geaflow.ai.subgraph.SubGraph; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class CastsOperator implements SearchOperator { + + private static final Logger LOGGER = LoggerFactory.getLogger(CastsOperator.class); + + private final GraphAccessor graphAccessor; + private final CastsRestClient castsClient; + + public CastsOperator(GraphAccessor graphAccessor, CastsRestClient castsClient) { + this.graphAccessor = Objects.requireNonNull(graphAccessor); + this.castsClient = Objects.requireNonNull(castsClient); + } + + @Override + public List apply(List subGraphList, VectorSearch search) { + List castsVectors = search.getVectorMap().get(VectorType.CastsVector); + if (castsVectors == null || castsVectors.isEmpty()) { + return subGraphList == null ? 
new ArrayList<>() : new ArrayList<>(subGraphList); + } + + CastsVector castsVector = (CastsVector) castsVectors.get(0); + String goal = castsVector.getGoal(); + int maxDepth = castsVector.getMaxDepth() > 0 ? castsVector.getMaxDepth() : 5; + + String sessionId = search.getSessionId(); + String runId = sessionId; + + if (subGraphList == null || subGraphList.isEmpty()) { + return new ArrayList<>(); + } + + String schemaFingerprint = buildSchemaFingerprint(graphAccessor.getGraphSchema()); + String schemaSummary = buildSchemaSummary(graphAccessor.getGraphSchema()); + + List out = new ArrayList<>(subGraphList.size()); + for (SubGraph subGraph : subGraphList) { + out.add(applyOneSubgraph(subGraph, goal, maxDepth, runId, schemaFingerprint, schemaSummary)); + } + return out; + } + + private SubGraph applyOneSubgraph( + SubGraph subGraph, + String goal, + int maxDepth, + String runId, + String schemaFingerprint, + String schemaSummary + ) { + Set entitySet = new HashSet<>(subGraph.getGraphEntityList()); + + GraphVertex currentVertex = pickCurrentVertex(subGraph); + if (currentVertex == null) { + return subGraph; + } + + GraphEntity current = currentVertex; + String entryVertexIdForEdge = currentVertex.getVertex().getId(); + + String signature = "V()"; + AiTrace trace = AiTrace.newTrace("geaflow-ai"); + AiScope scope = AiScope.withRunId(runId); + + for (int step = 0; step < maxDepth; step++) { + CastsDecisionRequest request = buildRequest( + goal, + signature, + step, + current, + schemaFingerprint, + schemaSummary + ); + + CastsDecisionResponse resp; + try { + resp = castsClient.decision(scope, trace, request); + } catch (Exception e) { + LOGGER.warn("CASTS service call failed, stop traversal: {}", e.getMessage()); + break; + } + + if (resp == null || resp.decision == null) { + break; + } + + CastsDecisionParser.ParsedDecision parsed = CastsDecisionParser.parse(resp.decision); + if (parsed.kind == CastsDecisionParser.Kind.STOP || parsed.kind == 
CastsDecisionParser.Kind.UNKNOWN) { + break; + } + + boolean progressed; + TraversalResult tr = executeDecision(current, entryVertexIdForEdge, parsed); + progressed = tr.progressed; + if (!progressed) { + break; + } + + // Apply entities into the subgraph (dedupe). + for (GraphEntity e : tr.addEntities) { + if (entitySet.add(e)) { + subGraph.addEntity(e); + } + } + + // Move cursor. + current = tr.nextEntity; + if (current instanceof GraphVertex) { + entryVertexIdForEdge = ((GraphVertex) current).getVertex().getId(); + } else if (current instanceof GraphEdge) { + entryVertexIdForEdge = tr.entryVertexIdForEdge != null ? tr.entryVertexIdForEdge : entryVertexIdForEdge; + } + + signature = signature + "." + parsed.raw; + } + + return subGraph; + } + + private static class TraversalResult { + final boolean progressed; + final GraphEntity nextEntity; + final List addEntities; + final String entryVertexIdForEdge; + + TraversalResult(boolean progressed, GraphEntity nextEntity, List addEntities, String entryVertexIdForEdge) { + this.progressed = progressed; + this.nextEntity = nextEntity; + this.addEntities = addEntities; + this.entryVertexIdForEdge = entryVertexIdForEdge; + } + } + + private TraversalResult executeDecision(GraphEntity current, String entryVertexIdForEdge, CastsDecisionParser.ParsedDecision parsed) { + if (current instanceof GraphVertex) { + GraphVertex v = (GraphVertex) current; + String vid = v.getVertex().getId(); + List edges = expandEdges(v); + + if (parsed.kind == CastsDecisionParser.Kind.OUT) { + List outEdges = edges.stream() + .filter(e -> parsed.label.equals(e.getLabel())) + .filter(e -> vid.equals(e.getEdge().getSrcId())) + .collect(Collectors.toList()); + return traverseToVertexFromEdges(vid, outEdges, true); + } + + if (parsed.kind == CastsDecisionParser.Kind.IN) { + List inEdges = edges.stream() + .filter(e -> parsed.label.equals(e.getLabel())) + .filter(e -> vid.equals(e.getEdge().getDstId())) + .collect(Collectors.toList()); + return 
traverseToVertexFromEdges(vid, inEdges, false); + } + + if (parsed.kind == CastsDecisionParser.Kind.BOTH) { + List bothEdges = edges.stream() + .filter(e -> parsed.label.equals(e.getLabel())) + .filter(e -> vid.equals(e.getEdge().getSrcId()) || vid.equals(e.getEdge().getDstId())) + .collect(Collectors.toList()); + if (bothEdges.isEmpty()) { + return new TraversalResult(false, current, List.of(), null); + } + List add = new ArrayList<>(); + add.addAll(bothEdges); + GraphVertex next = null; + for (GraphEdge e : bothEdges) { + boolean isOut = vid.equals(e.getEdge().getSrcId()); + String nextId = isOut ? e.getEdge().getDstId() : e.getEdge().getSrcId(); + GraphVertex v2 = graphAccessor.getVertex(null, nextId); + if (v2 != null) { + add.add(v2); + if (next == null) { + next = v2; + } + } + } + if (next == null) { + return new TraversalResult(false, current, List.of(), null); + } + return new TraversalResult(true, next, add, vid); + } + + if (parsed.kind == CastsDecisionParser.Kind.OUT_E) { + GraphEdge next = edges.stream() + .filter(e -> parsed.label.equals(e.getLabel())) + .filter(e -> vid.equals(e.getEdge().getSrcId())) + .findFirst() + .orElse(null); + if (next == null) { + return new TraversalResult(false, current, List.of(), null); + } + return new TraversalResult(true, next, new ArrayList<>(List.of(next)), vid); + } + + if (parsed.kind == CastsDecisionParser.Kind.IN_E) { + GraphEdge next = edges.stream() + .filter(e -> parsed.label.equals(e.getLabel())) + .filter(e -> vid.equals(e.getEdge().getDstId())) + .findFirst() + .orElse(null); + if (next == null) { + return new TraversalResult(false, current, List.of(), null); + } + return new TraversalResult(true, next, new ArrayList<>(List.of(next)), vid); + } + + if (parsed.kind == CastsDecisionParser.Kind.BOTH_E) { + GraphEdge next = edges.stream() + .filter(e -> parsed.label.equals(e.getLabel())) + .filter(e -> vid.equals(e.getEdge().getSrcId()) || vid.equals(e.getEdge().getDstId())) + .findFirst() + .orElse(null); + 
if (next == null) { + return new TraversalResult(false, current, List.of(), null); + } + return new TraversalResult(true, next, new ArrayList<>(List.of(next)), vid); + } + + // Filter/modifier steps are treated as no-op for execution. + return new TraversalResult(true, current, List.of(), null); + } + + if (current instanceof GraphEdge) { + GraphEdge e = (GraphEdge) current; + String srcId = e.getEdge().getSrcId(); + String dstId = e.getEdge().getDstId(); + + if (parsed.kind == CastsDecisionParser.Kind.IN_V) { + GraphVertex next = graphAccessor.getVertex(null, dstId); + if (next == null) { + return new TraversalResult(false, current, List.of(), null); + } + return new TraversalResult(true, next, new ArrayList<>(List.of(next)), null); + } + if (parsed.kind == CastsDecisionParser.Kind.OUT_V) { + GraphVertex next = graphAccessor.getVertex(null, srcId); + if (next == null) { + return new TraversalResult(false, current, List.of(), null); + } + return new TraversalResult(true, next, new ArrayList<>(List.of(next)), null); + } + if (parsed.kind == CastsDecisionParser.Kind.OTHER_V) { + String nextId = entryVertexIdForEdge != null && entryVertexIdForEdge.equals(srcId) ? dstId : srcId; + GraphVertex next = graphAccessor.getVertex(null, nextId); + if (next == null) { + return new TraversalResult(false, current, List.of(), null); + } + return new TraversalResult(true, next, new ArrayList<>(List.of(next)), null); + } + + // Modifier steps as no-op. + return new TraversalResult(true, current, List.of(), null); + } + + return new TraversalResult(false, current, List.of(), null); + } + + private TraversalResult traverseToVertexFromEdges(String currentVertexId, List edges, boolean outgoing) { + if (edges == null || edges.isEmpty()) { + return new TraversalResult(false, null, List.of(), null); + } + List add = new ArrayList<>(); + add.addAll(edges); + + GraphVertex next = null; + for (GraphEdge e : edges) { + String nextId = outgoing ? 
e.getEdge().getDstId() : e.getEdge().getSrcId(); + GraphVertex v = graphAccessor.getVertex(null, nextId); + if (v != null) { + add.add(v); + if (next == null) { + next = v; + } + } + } + if (next == null) { + return new TraversalResult(false, null, List.of(), null); + } + return new TraversalResult(true, next, add, currentVertexId); + } + + private List expandEdges(GraphVertex v) { + List expanded = graphAccessor.expand(v); + List edges = new ArrayList<>(); + for (GraphEntity e : expanded) { + if (e instanceof GraphEdge) { + edges.add((GraphEdge) e); + } + } + return edges; + } + + private GraphVertex pickCurrentVertex(SubGraph subGraph) { + List entities = subGraph.getGraphEntityList(); + for (int i = entities.size() - 1; i >= 0; i--) { + GraphEntity e = entities.get(i); + if (e instanceof GraphVertex) { + return (GraphVertex) e; + } + } + return null; + } + + private CastsDecisionRequest buildRequest( + String goal, + String signature, + int stepIndex, + GraphEntity current, + String schemaFingerprint, + String schemaSummary + ) { + CastsDecisionRequest req = new CastsDecisionRequest(); + req.goal = goal; + req.traversal = new CastsDecisionRequest.Traversal(); + req.traversal.structuralSignature = signature; + req.traversal.stepIndex = stepIndex; + + req.node = new CastsDecisionRequest.Node(); + req.node.properties = extractProperties(current); + req.node.label = current.getLabel(); + + if (!req.node.properties.containsKey("type")) { + req.node.properties.put("type", current.getLabel()); + } + + List outgoing = new ArrayList<>(); + List incoming = new ArrayList<>(); + if (current instanceof GraphVertex) { + GraphVertex v = (GraphVertex) current; + String vid = v.getVertex().getId(); + for (GraphEdge e : expandEdges(v)) { + if (vid.equals(e.getEdge().getSrcId())) { + outgoing.add(e.getLabel()); + } else if (vid.equals(e.getEdge().getDstId())) { + incoming.add(e.getLabel()); + } + } + } + + req.graphSchema = new CastsDecisionRequest.GraphSchema(); + 
req.graphSchema.schemaFingerprint = schemaFingerprint; + req.graphSchema.schemaSummary = schemaSummary; + req.graphSchema.validOutgoingLabels = outgoing.stream().distinct().sorted().collect(Collectors.toList()); + req.graphSchema.validIncomingLabels = incoming.stream().distinct().sorted().collect(Collectors.toList()); + + return req; + } + + private Map extractProperties(GraphEntity entity) { + Map props = new HashMap<>(); + if (entity instanceof GraphVertex) { + GraphVertex v = (GraphVertex) entity; + Schema schema = graphAccessor.getGraphSchema().getSchema(v.getLabel()); + if (schema instanceof VertexSchema) { + List fields = ((VertexSchema) schema).getFields(); + List values = v.getVertex().getValues(); + for (int i = 0; i < Math.min(fields.size(), values.size()); i++) { + props.put(fields.get(i), values.get(i)); + } + } + props.put("id", v.getVertex().getId()); + } else if (entity instanceof GraphEdge) { + GraphEdge e = (GraphEdge) entity; + Schema schema = graphAccessor.getGraphSchema().getSchema(e.getLabel()); + if (schema instanceof EdgeSchema) { + List fields = ((EdgeSchema) schema).getFields(); + List values = e.getEdge().getValues(); + for (int i = 0; i < Math.min(fields.size(), values.size()); i++) { + props.put(fields.get(i), values.get(i)); + } + } + props.put("srcId", e.getEdge().getSrcId()); + props.put("dstId", e.getEdge().getDstId()); + } + return props; + } + + private static String buildSchemaFingerprint(GraphSchema schema) { + if (schema == null) { + return "schema_unknown"; + } + String name = schema.getName(); + return "schema_" + (name == null ? 
"unknown" : name); + } + + private static String buildSchemaSummary(GraphSchema schema) { + if (schema == null) { + return ""; + } + List nodes = schema.getVertexSchemaList().stream() + .map(VertexSchema::getLabel) + .distinct() + .sorted() + .collect(Collectors.toList()); + List edges = schema.getEdgeSchemaList().stream() + .map(EdgeSchema::getLabel) + .distinct() + .sorted() + .collect(Collectors.toList()); + return "node_types=" + nodes + ", edge_labels=" + edges; + } +} diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsRestClient.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsRestClient.java new file mode 100644 index 000000000..04e1b69fd --- /dev/null +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/casts/CastsRestClient.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.geaflow.ai.casts; + +import java.util.Objects; +import org.apache.geaflow.ai.protocol.AiEnvelope; +import org.apache.geaflow.ai.protocol.AiHttpClient; +import org.apache.geaflow.ai.protocol.AiScope; +import org.apache.geaflow.ai.protocol.AiTrace; + +public class CastsRestClient { + + private final AiHttpClient client; + + public CastsRestClient(CastsConfig config) { + Objects.requireNonNull(config); + this.client = new AiHttpClient(config.decisionUrl(), config.token); + } + + public CastsDecisionResponse decision(AiScope scope, AiTrace trace, CastsDecisionRequest payload) { + AiEnvelope env = AiEnvelope.of(scope, trace, payload); + return client.executePayload(env, CastsDecisionResponse.class); + } +} + diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/CastsVector.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/CastsVector.java new file mode 100644 index 000000000..dde0d1b4c --- /dev/null +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/CastsVector.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.geaflow.ai.index.vector; + +import java.util.Objects; + +/** + * Carrier vector for CASTS traversal. This is not used for vector matching; it is an instruction + * to invoke CASTS decisioning with a goal and depth budget. + */ +public class CastsVector implements IVector { + + private final String goal; + private final int maxDepth; + + public CastsVector(String goal, int maxDepth) { + this.goal = Objects.requireNonNull(goal); + this.maxDepth = maxDepth; + } + + public String getGoal() { + return goal; + } + + public int getMaxDepth() { + return maxDepth; + } + + @Override + public double match(IVector other) { + return 0.0; + } + + @Override + public VectorType getType() { + return VectorType.CastsVector; + } + + @Override + public String toString() { + return "CastsVector{" + + "goal='" + goal + '\'' + + ", maxDepth=" + maxDepth + + '}'; + } +} + diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/VectorType.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/VectorType.java index 1f076c8e7..c7927fc5f 100644 --- a/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/VectorType.java +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/index/vector/VectorType.java @@ -23,5 +23,6 @@ public enum VectorType { TraversalVector, EmbeddingVector, MagnitudeVector, - KeywordVector + KeywordVector, + CastsVector } diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/memory/LightMemConfig.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/memory/LightMemConfig.java new file mode 100644 index 000000000..a67fed3d6 --- /dev/null +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/memory/LightMemConfig.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
/**
 * Connection settings for the LightMem service: base URL plus an optional
 * bearer token (normalized to the empty string when absent).
 */
public class LightMemConfig {

    public final String baseUrl;
    public final String token;

    public LightMemConfig(String baseUrl, String token) {
        this.baseUrl = baseUrl;
        this.token = token == null ? "" : token;
    }

    /**
     * Reads configuration from the environment.
     * Falls back to http://localhost:5002 when GEAFLOW_AI_LIGHTMEM_URL is
     * unset or empty; GEAFLOW_AI_LIGHTMEM_TOKEN may be null.
     */
    public static LightMemConfig fromEnv() {
        String url = System.getenv("GEAFLOW_AI_LIGHTMEM_URL");
        String effectiveUrl = (url == null || url.isEmpty()) ? "http://localhost:5002" : url;
        return new LightMemConfig(effectiveUrl, System.getenv("GEAFLOW_AI_LIGHTMEM_TOKEN"));
    }

    /** Full endpoint for memory writes. */
    public String writeUrl() {
        return join(baseUrl, "/memory/write");
    }

    /** Full endpoint for memory recall. */
    public String recallUrl() {
        return join(baseUrl, "/memory/recall");
    }

    /** Joins base and path, yielding exactly one '/' between them. */
    private static String join(String base, String path) {
        boolean baseSlash = base.endsWith("/");
        boolean pathSlash = path.startsWith("/");
        if (baseSlash && pathSlash) {
            return base.substring(0, base.length() - 1) + path;
        }
        if (!baseSlash && !pathSlash) {
            return base + "/" + path;
        }
        return base + path;
    }
}
contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.ai.memory; + +import java.util.Objects; +import org.apache.geaflow.ai.protocol.AiEnvelope; +import org.apache.geaflow.ai.protocol.AiHttpClient; +import org.apache.geaflow.ai.protocol.AiScope; +import org.apache.geaflow.ai.protocol.AiTrace; + +public class LightMemRestClient { + + private final AiHttpClient writeClient; + private final AiHttpClient recallClient; + + public LightMemRestClient(LightMemConfig config) { + Objects.requireNonNull(config); + this.writeClient = new AiHttpClient(config.writeUrl(), config.token); + this.recallClient = new AiHttpClient(config.recallUrl(), config.token); + } + + public MemoryWriteResponse write(AiScope scope, AiTrace trace, MemoryWriteRequest request) { + AiEnvelope env = AiEnvelope.of(scope, trace, request); + return writeClient.executePayload(env, MemoryWriteResponse.class); + } + + public MemoryRecallResponse recall(AiScope scope, AiTrace trace, MemoryRecallRequest request) { + AiEnvelope env = AiEnvelope.of(scope, trace, request); + return recallClient.executePayload(env, MemoryRecallResponse.class); + } +} + diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/memory/MemoryClient.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/memory/MemoryClient.java 
new file mode 100644 index 000000000..7a490c6d8 --- /dev/null +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/memory/MemoryClient.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.ai.memory; + +import java.util.Objects; +import org.apache.geaflow.ai.protocol.AiScope; +import org.apache.geaflow.ai.protocol.AiTrace; + +public class MemoryClient { + + private final LightMemRestClient restClient; + private final String caller; + + public MemoryClient(LightMemRestClient restClient, String caller) { + this.restClient = Objects.requireNonNull(restClient); + this.caller = caller == null || caller.isEmpty() ? 
"geaflow-ai" : caller; + } + + public MemoryWriteResponse write(AiScope scope, MemoryWriteRequest request) { + return restClient.write(scope, AiTrace.newTrace(caller), request); + } + + public MemoryRecallResponse recall(AiScope scope, MemoryRecallRequest request) { + return restClient.recall(scope, AiTrace.newTrace(caller), request); + } +} diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/memory/MemoryRecallRequest.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/memory/MemoryRecallRequest.java new file mode 100644 index 000000000..51b40e3e0 --- /dev/null +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/memory/MemoryRecallRequest.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
/**
 * Request body for LightMem recall: a free-text query plus an optional result
 * limit (null means "service default").
 */
public class MemoryRecallRequest {

    /** Free-text recall query. */
    public String query;

    /** Maximum number of units to return; null lets the service decide. */
    public Integer limit;

    /** No-arg constructor for JSON deserialization. */
    public MemoryRecallRequest() {}

    public MemoryRecallRequest(String query, Integer limit) {
        this.query = query;
        this.limit = limit;
    }
}
+ */ + +package org.apache.geaflow.ai.memory; + +import com.google.gson.annotations.SerializedName; +import java.util.List; +import java.util.Map; + +public class MemoryRecallResponse { + + public List units; + + public List provenance; + + @SerializedName("trace_id") + public String traceId; + + public static class Unit { + public String id; + public String content; + public Map metadata; + + @SerializedName("last_event_id") + public String lastEventId; + } + + public static class HitProvenance { + @SerializedName("memory_id") + public String memoryId; + + @SerializedName("view_source") + public String viewSource; + + @SerializedName("last_event_id") + public String lastEventId; + + public Double similarity; + } +} + diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/memory/MemoryWriteRequest.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/memory/MemoryWriteRequest.java new file mode 100644 index 000000000..33809cad6 --- /dev/null +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/memory/MemoryWriteRequest.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
/**
 * Request body for LightMem writes: the conversation messages to ingest plus
 * the kernel write mode.
 *
 * <p>NOTE(review): the element type of {@code messages} appears to have been
 * stripped when this patch was captured; reconstructed as
 * {@code List<Message>} — confirm against the original source.
 */
public class MemoryWriteRequest {

    /** Messages to ingest; initialized so callers can add without null checks. */
    public List<Message> messages = new ArrayList<>();

    /**
     * Write mode for the kernel. The current minimal implementation supports:
     * - "echo": store user messages directly (no LLM)
     */
    public String mode;

    /** One chat message (role + content). */
    public static class Message {
        public String role;
        public String content;

        /** No-arg constructor for JSON deserialization. */
        public Message() {}

        public Message(String role, String content) {
            this.role = role;
            this.content = content;
        }
    }
}
+ */ + +package org.apache.geaflow.ai.memory; + +import com.google.gson.annotations.SerializedName; +import java.util.List; +import java.util.Map; + +public class MemoryWriteResponse { + + public List actions; + + public Map provenance; + + public static class Action { + @SerializedName("event_type") + public String eventType; + + @SerializedName("unit_id") + public String unitId; + + public String content; + + @SerializedName("previous_content") + public String previousContent; + } +} + diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/protocol/AiEnvelope.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/protocol/AiEnvelope.java new file mode 100644 index 000000000..703f8029d --- /dev/null +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/protocol/AiEnvelope.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.geaflow.ai.protocol; + +import com.google.gson.annotations.SerializedName; + +public class AiEnvelope { + + @SerializedName("api_version") + public String apiVersion = "v1"; + + @SerializedName("scope") + public AiScope scope; + + @SerializedName("trace") + public AiTrace trace; + + @SerializedName("payload") + public Object payload; + + public static AiEnvelope of(AiScope scope, AiTrace trace, Object payload) { + AiEnvelope env = new AiEnvelope(); + env.scope = scope; + env.trace = trace; + env.payload = payload; + return env; + } +} + diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/protocol/AiError.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/protocol/AiError.java new file mode 100644 index 000000000..7137a455c --- /dev/null +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/protocol/AiError.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
/** Structured error returned by an AI service: machine code + human message. */
public class AiError {
    /** Machine-readable error code. */
    public String code;
    /** Human-readable description. */
    public String message;
}
+ */ + +package org.apache.geaflow.ai.protocol; + +import com.google.gson.Gson; +import java.io.IOException; +import java.util.Objects; +import java.util.concurrent.TimeUnit; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import org.apache.geaflow.ai.common.config.Constants; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class AiHttpClient { + + private static final Logger LOGGER = LoggerFactory.getLogger(AiHttpClient.class); + private static final Gson GSON = new Gson(); + private static OkHttpClient CLIENT; + + private static final MediaType JSON = MediaType.parse("application/json; charset=utf-8"); + + private final String url; + private final String token; + + public AiHttpClient(String url, String token) { + this.url = Objects.requireNonNull(url); + this.token = token == null ? "" : token; + if (CLIENT == null) { + OkHttpClient.Builder builder = new OkHttpClient.Builder(); + builder.callTimeout(Constants.HTTP_CALL_TIMEOUT_SECONDS, TimeUnit.SECONDS); + builder.connectTimeout(Constants.HTTP_CONNECT_TIMEOUT_SECONDS, TimeUnit.SECONDS); + builder.readTimeout(Constants.HTTP_READ_TIMEOUT_SECONDS, TimeUnit.SECONDS); + builder.writeTimeout(Constants.HTTP_WRITE_TIMEOUT_SECONDS, TimeUnit.SECONDS); + CLIENT = builder.build(); + } + } + + public AiResponse execute(AiEnvelope envelope) { + String bodyJson = GSON.toJson(envelope); + RequestBody body = RequestBody.create(JSON, bodyJson); + Request.Builder builder = new Request.Builder().url(url).post(body) + .addHeader("Content-Type", "application/json; charset=utf-8"); + if (!token.isEmpty()) { + builder.addHeader("Authorization", "Bearer " + token); + } + Request request = builder.build(); + + try (okhttp3.Response response = CLIENT.newCall(request).execute()) { + if (response.body() == null) { + throw new RuntimeException("AI response body is null"); + } + String resp = response.body().string(); + if (!response.isSuccessful()) { + 
LOGGER.warn("AI request failed url={}, code={}, body={}", url, response.code(), resp); + throw new RuntimeException("AI request failed with code=" + response.code()); + } + return GSON.fromJson(resp, AiResponse.class); + } catch (IOException e) { + throw new RuntimeException("AI HTTP call failed", e); + } + } + + public T executePayload(AiEnvelope envelope, Class payloadType) { + AiResponse resp = execute(envelope); + if (!Boolean.TRUE.equals(resp.ok)) { + String err = resp.error != null ? (resp.error.code + ": " + resp.error.message) : "unknown error"; + throw new RuntimeException("AI returned ok=false: " + err); + } + if (resp.payload == null) { + return null; + } + return GSON.fromJson(resp.payload, payloadType); + } +} + diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/protocol/AiResponse.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/protocol/AiResponse.java new file mode 100644 index 000000000..82d3f096f --- /dev/null +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/protocol/AiResponse.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.geaflow.ai.protocol; + +import com.google.gson.JsonElement; +import com.google.gson.annotations.SerializedName; + +public class AiResponse { + + @SerializedName("ok") + public Boolean ok; + + @SerializedName("api_version") + public String apiVersion; + + @SerializedName("trace") + public AiTrace trace; + + @SerializedName("payload") + public JsonElement payload; + + @SerializedName("error") + public AiError error; +} + diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/protocol/AiScope.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/protocol/AiScope.java new file mode 100644 index 000000000..20a1df54c --- /dev/null +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/protocol/AiScope.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.geaflow.ai.protocol; + +import com.google.gson.annotations.SerializedName; + +public class AiScope { + + @SerializedName("tenant_id") + public String tenantId; + + @SerializedName("user_id") + public String userId; + + @SerializedName("agent_id") + public String agentId; + + @SerializedName("run_id") + public String runId; + + @SerializedName("actor_id") + public String actorId; + + public static AiScope withRunId(String runId) { + AiScope scope = new AiScope(); + scope.runId = runId; + return scope; + } +} + diff --git a/geaflow-ai/src/main/java/org/apache/geaflow/ai/protocol/AiTrace.java b/geaflow-ai/src/main/java/org/apache/geaflow/ai/protocol/AiTrace.java new file mode 100644 index 000000000..9addf52d6 --- /dev/null +++ b/geaflow-ai/src/main/java/org/apache/geaflow/ai/protocol/AiTrace.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.geaflow.ai.protocol; + +import com.google.gson.annotations.SerializedName; +import java.util.UUID; + +public class AiTrace { + + @SerializedName("trace_id") + public String traceId; + + @SerializedName("timestamp") + public Double timestamp; + + @SerializedName("caller") + public String caller; + + public static AiTrace newTrace(String caller) { + AiTrace trace = new AiTrace(); + trace.traceId = "tr_" + UUID.randomUUID().toString().replace("-", ""); + trace.timestamp = System.currentTimeMillis() / 1000.0; + trace.caller = caller; + return trace; + } +} + From b96d4767c461002b888b8a6976f323eb5d37affa Mon Sep 17 00:00:00 2001 From: appointat <65004114+Appointat@users.noreply.github.com> Date: Tue, 3 Mar 2026 11:48:35 +0800 Subject: [PATCH 2/4] feat(tests): add license header validation tests for casts and lightmem plugins --- .github/workflows/ci-py311.yml | 38 ++++++++++ .../casts/tests/test_license_headers.py | 69 +++++++++++++++++++ .../lightmem/tests/test_license_headers.py | 65 +++++++++++++++++ 3 files changed, 172 insertions(+) create mode 100644 geaflow-ai/plugins/casts/tests/test_license_headers.py create mode 100644 geaflow-ai/plugins/lightmem/tests/test_license_headers.py diff --git a/.github/workflows/ci-py311.yml b/.github/workflows/ci-py311.yml index bece3d264..fb03e3c60 100644 --- a/.github/workflows/ci-py311.yml +++ b/.github/workflows/ci-py311.yml @@ -91,3 +91,41 @@ jobs: - name: Run tests run: | uv run --no-editable pytest -q + + smoke: + name: smoke-${{ matrix.plugin }} (py${{ matrix.python-version }}) + runs-on: ubuntu-latest + needs: tests + strategy: + fail-fast: false + matrix: + python-version: ["3.11"] + plugin: ["casts", "lightmem"] + + defaults: + run: + shell: bash + working-directory: geaflow-ai/plugins/${{ matrix.plugin }} + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Setup uv 
+ uses: astral-sh/setup-uv@v4 + with: + version: "0.9.17" + enable-cache: true + + - name: Prepare venv + run: | + python -m venv .venv + + - name: Smoke test + run: | + bash ./scripts/smoke.sh diff --git a/geaflow-ai/plugins/casts/tests/test_license_headers.py b/geaflow-ai/plugins/casts/tests/test_license_headers.py new file mode 100644 index 000000000..ac713b530 --- /dev/null +++ b/geaflow-ai/plugins/casts/tests/test_license_headers.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from pathlib import Path + +LICENSE_HEADER_TOKEN = "Licensed to the Apache Software Foundation (ASF) under one" + +SKIP_DIRS = { + ".git", + ".mypy_cache", + ".pytest_cache", + ".ruff_cache", + ".venv", + "__pycache__", + "build", + "dist", +} + + +def _python_files_under(root: Path) -> list[Path]: + python_files: list[Path] = [] + for path in root.rglob("*.py"): + if any(part in SKIP_DIRS for part in path.parts): + continue + python_files.append(path) + return python_files + + +def _has_license_header(path: Path) -> bool: + # Be robust to: + # - shebangs (`#!/usr/bin/env python3`) + # - encoding lines (`# -*- coding: utf-8 -*-`) + # - minor whitespace differences (`# Licensed...` vs `# Licensed...`) + try: + head_lines = path.read_text(encoding="utf-8", errors="replace").splitlines()[:30] + except OSError: + return False + return LICENSE_HEADER_TOKEN in "\n".join(head_lines) + + +def test_all_python_files_have_license_header() -> None: + plugin_root = Path(__file__).resolve().parents[1] + + missing = sorted( + (path.relative_to(plugin_root) for path in _python_files_under(plugin_root) if not _has_license_header(path)), + key=lambda p: str(p), + ) + + assert not missing, ( + "Missing ASF license header in the following Python files (expected within the first 30 lines):\n" + + "\n".join(f"- {path}" for path in missing) + ) + diff --git a/geaflow-ai/plugins/lightmem/tests/test_license_headers.py b/geaflow-ai/plugins/lightmem/tests/test_license_headers.py new file mode 100644 index 000000000..d691ae3ea --- /dev/null +++ b/geaflow-ai/plugins/lightmem/tests/test_license_headers.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from pathlib import Path + +LICENSE_HEADER_TOKEN = "Licensed to the Apache Software Foundation (ASF) under one" + +SKIP_DIRS = { + ".git", + ".mypy_cache", + ".pytest_cache", + ".ruff_cache", + ".venv", + "__pycache__", + "build", + "dist", +} + + +def _python_files_under(root: Path) -> list[Path]: + python_files: list[Path] = [] + for path in root.rglob("*.py"): + if any(part in SKIP_DIRS for part in path.parts): + continue + python_files.append(path) + return python_files + + +def _has_license_header(path: Path) -> bool: + try: + head_lines = path.read_text(encoding="utf-8", errors="replace").splitlines()[:30] + except OSError: + return False + return LICENSE_HEADER_TOKEN in "\n".join(head_lines) + + +def test_all_python_files_have_license_header() -> None: + plugin_root = Path(__file__).resolve().parents[1] + + missing = sorted( + (path.relative_to(plugin_root) for path in _python_files_under(plugin_root) if not _has_license_header(path)), + key=lambda p: str(p), + ) + + assert not missing, ( + "Missing ASF license header in the following Python files (expected within the first 30 lines):\n" + + "\n".join(f"- {path}" for path in missing) + ) + From 3e19c8fed20e096bff6f20b89d8ca0c0321564a5 Mon Sep 17 00:00:00 2001 From: appointat <65004114+Appointat@users.noreply.github.com> Date: Tue, 3 Mar 2026 11:48:45 +0800 Subject: [PATCH 3/4] feat: add README for 
LightMem plugin with architecture and usage details --- geaflow-ai/plugins/lightmem/README.md | 191 ++++++++++++++++++++++++++ 1 file changed, 191 insertions(+) create mode 100644 geaflow-ai/plugins/lightmem/README.md diff --git a/geaflow-ai/plugins/lightmem/README.md b/geaflow-ai/plugins/lightmem/README.md new file mode 100644 index 000000000..fff7c880d --- /dev/null +++ b/geaflow-ai/plugins/lightmem/README.md @@ -0,0 +1,191 @@ +# LightMem (Lightweight Context Memory) + +LightMem is a ledger-backed memory kernel for AI agents. It stores, retrieves, +and traces context memories through an append-only event ledger and vector-indexed +views, with strict scope isolation and full provenance on every recall. + +## Name Origin + +**LightMem** stands for **Lightweight Context Memory** — a minimal but complete +memory system that stays lightweight enough for local dev and deterministic testing, +while following a protocol designed for production use. + +## Design Goals + +- **Ledger-centric**: An append-only event log is the single source of truth for + all memory writes, updates, and deletions. +- **Scope-bounded isolation**: Every read and write is bound to a scope + (tenant/user/agent/run/actor). No cross-scope memory leakage. +- **Provenance-first**: Every recalled memory carries an evidence chain back to + the ledger event that created it (injection closure, evidence closure, decision + closure). +- **Minimal viable kernel**: Ships with hash-based embeddings for deterministic + testing; designed to be pluggable for production embedding backends. + +## Two Modes: Production vs. Echo + +### Production Mode (GeaFlow Java Integration) + +In production, the **GeaFlow Java data plane** acts as a proxy. The Java +`GeaFlowMemoryServer` accepts `/memory/write` and `/memory/recall` requests, +extracts scope from query parameters, and forwards them to the Python LightMem +service via `LightMemRestClient`. 
+ +Key integration files: + +- Java client: `geaflow-ai/src/main/java/org/apache/geaflow/ai/memory/LightMemRestClient.java` +- Java high-level API: `geaflow-ai/src/main/java/org/apache/geaflow/ai/memory/MemoryClient.java` +- Java server: `geaflow-ai/src/main/java/org/apache/geaflow/ai/GeaFlowMemoryServer.java` +- Python service: `geaflow-ai/plugins/lightmem/api/app.py` + +### Echo Mode (Smoke Testing) + +Setting `mode="echo"` in a write request extracts user messages directly as +memory units without invoking an LLM. This is the default mode, intended for +local development, CI, and integration smoke tests. + +A full LLM-driven pipeline (`mode="llm"`) following the 5-step write protocol +(Scope → Unitize → Ground → Decide → Commit) is defined in `AGENTS.md` and +is a planned extension. + +## Architecture + +### Write Pipeline + +``` +POST /memory/write (Envelope) + │ + ├─ Validate scope (reject if empty) + ├─ Unitize: extract user messages → Action(ADD) list + ├─ Commit: for each action: + │ ├─ ledger.append(scope, action) → LedgerEvent + │ └─ views.vector_view.upsert(unit) + └─ Return actions + provenance (event_ids) +``` + +### Recall Pipeline + +``` +POST /memory/recall (Envelope) + │ + ├─ Validate scope (reject if empty) + ├─ views.vector_view.recall(query, scope_filter, limit) + │ ├─ Embed query + │ ├─ Cosine-similarity search + │ ├─ Filter by scope (strict match) + │ └─ Return top-k hits + └─ Return units + provenance (memory_id, view_source, last_event_id, similarity) +``` + +## Module Layout + +- `core/`: Memory kernel — `MemoryKernel`, `Ledger`, `VectorView`, core data types +- `api/`: FastAPI HTTP service — endpoints, envelope schema, error handling +- `tests/`: Unit and integration tests +- `scripts/`: `smoke.sh` one-click integration test + +## API Endpoints + +All requests (except `/health`) use a standard JSON envelope: + +```json +{ + "api_version": "v1", + "scope": { "user_id": "u1", "agent_id": "a1" }, + "trace": { "trace_id": "tr_...", 
"timestamp": 1700000000, "caller": "my-app" }, + "payload": { ... } +} +``` + +| Endpoint | Method | Payload | Description | +|---|---|---|---| +| `/health` | GET | — | Returns `{"status": "UP"}` | +| `/memory/write` | POST | `messages`, `mode` | Write memories from messages | +| `/memory/recall` | POST | `query`, `limit` | Recall memories by similarity | + +Scope must contain at least one non-empty field (`tenant_id`, `user_id`, +`agent_id`, `run_id`, or `actor_id`). Requests with an empty scope are +rejected with HTTP 422. + +## Configuration + +Java-side environment variables (for the GeaFlow proxy): + +- `GEAFLOW_AI_LIGHTMEM_URL` — LightMem service URL (default: `http://localhost:5002`) +- `GEAFLOW_AI_LIGHTMEM_TOKEN` — Optional bearer token for authentication + +Smoke-test environment variables: + +- `GEAFLOW_AI_LIGHTMEM_HOST` — Host to bind (default: `127.0.0.1`) +- `GEAFLOW_AI_LIGHTMEM_PORT` — Port to bind (default: `5002`) +- `GEAFLOW_AI_LIGHTMEM_TIMEOUT_SECONDS` — Server startup timeout (default: `20`) + +## Local Dev (Python 3.11 + uv) + +Each Python plugin keeps its own virtual environment at `.venv/` (gitignored). + +One-time venv creation: + +```bash +cd geaflow-ai/plugins/lightmem +[ -d .venv ] || python3.11 -m venv .venv +``` + +Sync dependencies: + +```bash +cd geaflow-ai/plugins/lightmem +uv sync --extra dev +``` + +Notes: + +- You don't need to `source .venv/bin/activate` for normal workflows. + - `uv sync` installs into the project env (`.venv/`) + - `uv run ...` executes inside that env +- If you *do* activate a venv for interactive work, `uv sync --active` forces syncing + into the active environment. 
+ +## Running the Service (FastAPI) + +Start the LightMem service: + +```bash +cd geaflow-ai/plugins/lightmem +uv sync --extra dev +uv run uvicorn api.app:app --host 127.0.0.1 --port 5002 +``` + +One-click smoke test (starts/stops the server as needed): + +```bash +cd geaflow-ai/plugins/lightmem +./scripts/smoke.sh +``` + +## Tests + +Run tests locally: + +```bash +cd geaflow-ai/plugins/lightmem +uv sync --extra dev +uv run pytest -q tests +``` + +## Lint & Type Check + +Run lint (ruff) and type checks (mypy): + +```bash +cd geaflow-ai/plugins/lightmem +uv sync --extra dev +uv run ruff format --check . +uv run ruff check . +# LightMem uses package-dir mapping; mypy needs a non-editable install. +uv sync --extra dev --no-editable +uv run --no-editable mypy -p lightmem -p api +``` + +CI: `.github/workflows/ci-py311.yml` runs LightMem + CASTS Python tests on +Python 3.11 via `uv`. From 47459a3d5b5c8ae342725c45a1db61ffe21b9e54 Mon Sep 17 00:00:00 2001 From: appointat <65004114+Appointat@users.noreply.github.com> Date: Tue, 3 Mar 2026 14:28:23 +0800 Subject: [PATCH 4/4] feat: remove AGENTS.md file as part of plugin documentation cleanup --- geaflow-ai/plugins/casts/AGENTS.md | 71 ------------------------------ 1 file changed, 71 deletions(-) delete mode 100644 geaflow-ai/plugins/casts/AGENTS.md diff --git a/geaflow-ai/plugins/casts/AGENTS.md b/geaflow-ai/plugins/casts/AGENTS.md deleted file mode 100644 index 9de3b8898..000000000 --- a/geaflow-ai/plugins/casts/AGENTS.md +++ /dev/null @@ -1,71 +0,0 @@ -# CASTS Agent Instructions (geaflow-ai/plugins/casts) - -This file defines CASTS plugin-local instructions for coding agents. - -## Must-Read (Before You Change Code) - -- Review and follow: `geaflow-ai/plugins/CODE_STYLES.md` - - Treat it as the baseline contract for changes under `geaflow-ai/plugins/casts/`. - - If you need to break a rule, document the reason in the PR/commit message. 
- -## Repository Layout (What Goes Where) - -- `core/`: deterministic cache + decision logic (no network calls required). -- `services/`: integration code (LLM / embedding / external I/O). -- `harness/`: offline simulation harness (data + executor + evaluator). -- `api/`: production-facing decision service (FastAPI). - - Endpoint: `POST /casts/decision` - - Safety: must degrade conservatively to `decision="stop"` on invalid input or upstream failures. -- `scripts/`: local developer scripts (e.g., smoke tests). -- `tests/`: pytest suite. - -## Local Dev (Python 3.11 + uv) - -We use a per-plugin venv in `.venv/` (gitignored) and a **no-activate** workflow. - -One-time setup: - -```bash -cd geaflow-ai/plugins/casts -[ -d .venv ] || python3.11 -m venv .venv -uv sync --extra dev -``` - -Run tests: - -```bash -cd geaflow-ai/plugins/casts -uv run pytest -q -``` - -Run lint + type checks: - -```bash -cd geaflow-ai/plugins/casts -uv run ruff format --check . -uv run ruff check . -uv run mypy -p api -p core -p services -p harness -``` - -## Run The Service (FastAPI) - -```bash -cd geaflow-ai/plugins/casts -uv sync --extra dev -uv run uvicorn api.app:app --host 127.0.0.1 --port 5001 -``` - -One-click smoke: - -```bash -cd geaflow-ai/plugins/casts -./scripts/smoke.sh -``` - -## Safety Defaults - -- Never enable evaluating LLM-provided predicates in production: - - `LLM_ORACLE_ENABLE_PREDICATE_EVAL` must remain `False` by default. -- `scope` is a hard boundary: - - requests with empty scope must be rejected (or conservatively downgraded). - - CASTS service additionally requires `scope.run_id` for cache isolation; if missing, it downgrades to `stop`.