Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
The diff you're trying to view is too large. We only load the first 3000 changed files.
373 changes: 190 additions & 183 deletions geaflow-ai/src/main/java/org/apache/geaflow/ai/GeaFlowMemoryServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.apache.geaflow.ai.common.util.SeDeUtil;
import org.apache.geaflow.ai.graph.*;
import org.apache.geaflow.ai.graph.io.*;
Expand All @@ -40,193 +41,199 @@
@Controller
public class GeaFlowMemoryServer {

private static final Logger LOGGER = LoggerFactory.getLogger(GeaFlowMemoryServer.class);

private static final String SERVER_NAME = "geaflow-memory-server";
private static final int DEFAULT_PORT = 8080;

private static final ServerMemoryCache CACHE = new ServerMemoryCache();

public static void main(String[] args) {
System.setProperty("solon.app.name", SERVER_NAME);
Solon.start(GeaFlowMemoryServer.class, args, app -> {
app.cfg().loadAdd("application.yml");
int port = app.cfg().getInt("server.port", DEFAULT_PORT);
LOGGER.info("Starting {} on port {}", SERVER_NAME, port);
app.get("/", ctx -> {
private static final Logger LOGGER = LoggerFactory.getLogger(GeaFlowMemoryServer.class);

private static final String SERVER_NAME = "geaflow-memory-server";
private static final int DEFAULT_PORT = 8080;

private static final ServerMemoryCache CACHE = new ServerMemoryCache();

public static void main(String[] args) {
System.setProperty("solon.app.name", SERVER_NAME);
Solon.start(
GeaFlowMemoryServer.class,
args,
app -> {
app.cfg().loadAdd("application.yml");
int port = app.cfg().getInt("server.port", DEFAULT_PORT);
LOGGER.info("Starting {} on port {}", SERVER_NAME, port);
app.get(
"/",
ctx -> {
ctx.output("GeaFlow AI Server is running...");
});
app.get("/health", ctx -> {
});
app.get(
"/health",
ctx -> {
ctx.output("{\"status\":\"UP\",\"service\":\"" + SERVER_NAME + "\"}");
});
});
});
}

@Get
@Mapping("/api/test")
public String test() {
return "GeaFlow Memory Server is working!";
}

@Post
@Mapping("/graph/create")
public String createGraph(@Body String input) {
GraphSchema graphSchema = SeDeUtil.deserializeGraphSchema(input);
String graphName = graphSchema.getName();
if (graphName == null || CACHE.getGraphByName(graphName) != null) {
throw new RuntimeException("Cannot create graph name: " + graphName);
}

@Get
@Mapping("/api/test")
public String test() {
return "GeaFlow Memory Server is working!";
}

@Post
@Mapping("/graph/create")
public String createGraph(@Body String input) {
GraphSchema graphSchema = SeDeUtil.deserializeGraphSchema(input);
String graphName = graphSchema.getName();
if (graphName == null || CACHE.getGraphByName(graphName) != null) {
throw new RuntimeException("Cannot create graph name: " + graphName);
}
Map<String, EntityGroup> entities = new HashMap<>();
for (VertexSchema vertexSchema : graphSchema.getVertexSchemaList()) {
entities.put(vertexSchema.getName(), new VertexGroup(vertexSchema, new ArrayList<>()));
}
for (EdgeSchema edgeSchema : graphSchema.getEdgeSchemaList()) {
entities.put(edgeSchema.getName(), new EdgeGroup(edgeSchema, new ArrayList<>()));
}
MemoryGraph graph = new MemoryGraph(graphSchema, entities);
CACHE.putGraph(graph);
LocalMemoryGraphAccessor graphAccessor = new LocalMemoryGraphAccessor(graph);
LOGGER.info("Success to init empty graph.");

EntityAttributeIndexStore indexStore = new EntityAttributeIndexStore();
indexStore.initStore(new SubgraphSemanticPromptFunction(graphAccessor));
LOGGER.info("Success to init EntityAttributeIndexStore.");

GraphMemoryServer server = new GraphMemoryServer();
server.addGraphAccessor(graphAccessor);
server.addIndexStore(indexStore);
LOGGER.info("Success to init GraphMemoryServer.");
CACHE.putServer(server);

LOGGER.info("Success to init graph. SCHEMA: {}", graphSchema);
return "createGraph has been called, graphName: " + graphName;
}

@Post
@Mapping("/graph/addEntitySchema")
public String addSchema(@Param("graphName") String graphName,
@Body String input) {
Graph graph = CACHE.getGraphByName(graphName);
if (graph == null) {
throw new RuntimeException("Graph not exist.");
}
if (!(graph instanceof MemoryGraph)) {
throw new RuntimeException("Graph cannot modify.");
}
MemoryMutableGraph memoryMutableGraph = new MemoryMutableGraph((MemoryGraph) graph);
Schema schema = SeDeUtil.deserializeEntitySchema(input);
String schemaName = schema.getName();
if (schema instanceof VertexSchema) {
memoryMutableGraph.addVertexSchema((VertexSchema) schema);
} else if (schema instanceof EdgeSchema) {
memoryMutableGraph.addEdgeSchema((EdgeSchema) schema);
} else {
throw new RuntimeException("Cannt add schema: " + input);
}
return "addSchema has been called, schemaName: " + schemaName;
}

@Post
@Mapping("/graph/getGraphSchema")
public String getSchema(@Param("graphName") String graphName) {
Graph graph = CACHE.getGraphByName(graphName);
if (graph == null) {
throw new RuntimeException("Graph not exist.");
}
if (!(graph instanceof MemoryGraph)) {
throw new RuntimeException("Graph cannot modify.");
}
return SeDeUtil.serializeGraphSchema(graph.getGraphSchema());
}

@Post
@Mapping("/graph/insertEntity")
public String addEntity(@Param("graphName") String graphName,
@Body String input) {
Graph graph = CACHE.getGraphByName(graphName);
if (graph == null) {
throw new RuntimeException("Graph not exist.");
}
if (!(graph instanceof MemoryGraph)) {
throw new RuntimeException("Graph cannot modify.");
}
MemoryMutableGraph memoryMutableGraph = new MemoryMutableGraph((MemoryGraph) graph);
List<GraphEntity> graphEntities = SeDeUtil.deserializeEntities(input);

for (GraphEntity entity : graphEntities) {
if (entity instanceof GraphVertex) {
memoryMutableGraph.addVertex(((GraphVertex) entity).getVertex());
} else {
memoryMutableGraph.addEdge(((GraphEdge) entity).getEdge());
}
}
CACHE.getConsolidateServer().executeConsolidateTask(
Map<String, EntityGroup> entities = new HashMap<>();
for (VertexSchema vertexSchema : graphSchema.getVertexSchemaList()) {
entities.put(vertexSchema.getName(), new VertexGroup(vertexSchema, new ArrayList<>()));
}
for (EdgeSchema edgeSchema : graphSchema.getEdgeSchemaList()) {
entities.put(edgeSchema.getName(), new EdgeGroup(edgeSchema, new ArrayList<>()));
}
MemoryGraph graph = new MemoryGraph(graphSchema, entities);
CACHE.putGraph(graph);
LocalMemoryGraphAccessor graphAccessor = new LocalMemoryGraphAccessor(graph);
LOGGER.info("Success to init empty graph.");

EntityAttributeIndexStore indexStore = new EntityAttributeIndexStore();
indexStore.initStore(new SubgraphSemanticPromptFunction(graphAccessor));
LOGGER.info("Success to init EntityAttributeIndexStore.");

GraphMemoryServer server = new GraphMemoryServer();
server.addGraphAccessor(graphAccessor);
server.addIndexStore(indexStore);
LOGGER.info("Success to init GraphMemoryServer.");
CACHE.putServer(server);

LOGGER.info("Success to init graph. SCHEMA: {}", graphSchema);
return "createGraph has been called, graphName: " + graphName;
}

@Post
@Mapping("/graph/addEntitySchema")
public String addSchema(@Param("graphName") String graphName, @Body String input) {
Graph graph = CACHE.getGraphByName(graphName);
if (graph == null) {
throw new RuntimeException("Graph not exist.");
}
if (!(graph instanceof MemoryGraph)) {
throw new RuntimeException("Graph cannot modify.");
}
MemoryMutableGraph memoryMutableGraph = new MemoryMutableGraph((MemoryGraph) graph);
Schema schema = SeDeUtil.deserializeEntitySchema(input);
String schemaName = schema.getName();
if (schema instanceof VertexSchema) {
memoryMutableGraph.addVertexSchema((VertexSchema) schema);
} else if (schema instanceof EdgeSchema) {
memoryMutableGraph.addEdgeSchema((EdgeSchema) schema);
} else {
throw new RuntimeException("Cannt add schema: " + input);
}
return "addSchema has been called, schemaName: " + schemaName;
}

@Post
@Mapping("/graph/getGraphSchema")
public String getSchema(@Param("graphName") String graphName) {
Graph graph = CACHE.getGraphByName(graphName);
if (graph == null) {
throw new RuntimeException("Graph not exist.");
}
if (!(graph instanceof MemoryGraph)) {
throw new RuntimeException("Graph cannot modify.");
}
return SeDeUtil.serializeGraphSchema(graph.getGraphSchema());
}

@Post
@Mapping("/graph/insertEntity")
public String addEntity(@Param("graphName") String graphName, @Body String input) {
Graph graph = CACHE.getGraphByName(graphName);
if (graph == null) {
throw new RuntimeException("Graph not exist.");
}
if (!(graph instanceof MemoryGraph)) {
throw new RuntimeException("Graph cannot modify.");
}
MemoryMutableGraph memoryMutableGraph = new MemoryMutableGraph((MemoryGraph) graph);
List<GraphEntity> graphEntities = SeDeUtil.deserializeEntities(input);

for (GraphEntity entity : graphEntities) {
if (entity instanceof GraphVertex) {
memoryMutableGraph.addVertex(((GraphVertex) entity).getVertex());
} else {
memoryMutableGraph.addEdge(((GraphEdge) entity).getEdge());
}
}
CACHE
.getConsolidateServer()
.executeConsolidateTask(
CACHE.getServerByName(graphName).getGraphAccessors().get(0), memoryMutableGraph);
return "Success to add entities, num: " + graphEntities.size();
}

@Post
@Mapping("/graph/delEntity")
public String deleteEntity(@Param("graphName") String graphName,
@Body String input) {
Graph graph = CACHE.getGraphByName(graphName);
if (graph == null) {
throw new RuntimeException("Graph not exist.");
}
if (!(graph instanceof MemoryGraph)) {
throw new RuntimeException("Graph cannot modify.");
}
MemoryMutableGraph memoryMutableGraph = new MemoryMutableGraph((MemoryGraph) graph);
List<GraphEntity> graphEntities = SeDeUtil.deserializeEntities(input);
for (GraphEntity entity : graphEntities) {
if (entity instanceof GraphVertex) {
memoryMutableGraph.removeVertex(entity.getLabel(),
((GraphVertex) entity).getVertex().getId());
} else {
memoryMutableGraph.removeEdge(((GraphEdge) entity).getEdge());
}
}
return "Success to remove entities, num: " + graphEntities.size();
}

@Post
@Mapping("/query/context")
public String createContext(@Param("graphName") String graphName) {
GraphMemoryServer server = CACHE.getServerByName(graphName);
if (server == null) {
throw new RuntimeException("Server not exist.");
}
String sessionId = server.createSession();
CACHE.putSession(server, sessionId);
return sessionId;
}

@Post
@Mapping("/query/exec")
public String execQuery(@Param("sessionId") String sessionId,
@Body String query) {
String graphName = CACHE.getGraphNameBySession(sessionId);
if (graphName == null) {
throw new RuntimeException("Graph not exist.");
}
GraphMemoryServer server = CACHE.getServerByName(graphName);
VectorSearch search = new VectorSearch(null, sessionId);
search.addVector(new KeywordVector(query));
server.search(search);
Context context = server.verbalize(sessionId,
new SubgraphSemanticPromptFunction(server.getGraphAccessors().get(0)));
return context.toString();
}

@Post
@Mapping("/query/result")
public String getResult(@Param("sessionId") String sessionId) {
String graphName = CACHE.getGraphNameBySession(sessionId);
if (graphName == null) {
throw new RuntimeException("Graph not exist.");
}
GraphMemoryServer server = CACHE.getServerByName(graphName);
List<GraphEntity> result = server.getSessionEntities(sessionId);
return result.toString();
return "Success to add entities, num: " + graphEntities.size();
}

@Post
@Mapping("/graph/delEntity")
public String deleteEntity(@Param("graphName") String graphName, @Body String input) {
Graph graph = CACHE.getGraphByName(graphName);
if (graph == null) {
throw new RuntimeException("Graph not exist.");
}
if (!(graph instanceof MemoryGraph)) {
throw new RuntimeException("Graph cannot modify.");
}
MemoryMutableGraph memoryMutableGraph = new MemoryMutableGraph((MemoryGraph) graph);
List<GraphEntity> graphEntities = SeDeUtil.deserializeEntities(input);
for (GraphEntity entity : graphEntities) {
if (entity instanceof GraphVertex) {
memoryMutableGraph.removeVertex(
entity.getLabel(), ((GraphVertex) entity).getVertex().getId());
} else {
memoryMutableGraph.removeEdge(((GraphEdge) entity).getEdge());
}
}
return "Success to remove entities, num: " + graphEntities.size();
}

@Post
@Mapping("/query/context")
public String createContext(@Param("graphName") String graphName) {
GraphMemoryServer server = CACHE.getServerByName(graphName);
if (server == null) {
throw new RuntimeException("Server not exist.");
}
String sessionId = server.createSession();
CACHE.putSession(server, sessionId);
return sessionId;
}

@Post
@Mapping("/query/exec")
public String execQuery(@Param("sessionId") String sessionId, @Body String query) {
String graphName = CACHE.getGraphNameBySession(sessionId);
if (graphName == null) {
throw new RuntimeException("Graph not exist.");
}
GraphMemoryServer server = CACHE.getServerByName(graphName);
VectorSearch search = new VectorSearch(null, sessionId);
search.addVector(new KeywordVector(query));
server.search(search);
Context context =
server.verbalize(
sessionId, new SubgraphSemanticPromptFunction(server.getGraphAccessors().get(0)));
return context.toString();
}

@Post
@Mapping("/query/result")
public String getResult(@Param("sessionId") String sessionId) {
String graphName = CACHE.getGraphNameBySession(sessionId);
if (graphName == null) {
throw new RuntimeException("Graph not exist.");
}
GraphMemoryServer server = CACHE.getServerByName(graphName);
List<GraphEntity> result = server.getSessionEntities(sessionId);
return result.toString();
}
}
Loading